|
""" |
|
TestTime Task Generator |
|
|
|
AZR ์ถ๋ก ์ฉ ํ๋กฌํํธ ๊ธฐ๋ฐ Induction/Deduction/Abduction ํ์คํฌ ์์ฑ |
|
์๊ตฌ์ฌํญ 3: "AZR์ฒ๋ผ ํ
ํ๋ฆฟ์ ํ์ฉํ์ฌ induction, deduction, abduction ๋ฌธ์ ๋ฅผ ์์ฑ" |
|
""" |
|
|
|
from typing import Dict, List, Any, Optional, Tuple |
|
import random |
|
|
|
from .config import TestTimeConfig |
|
from .logger import TestTimeLogger |
|
|
|
from ..data_construction.prompts import get_code_problem_predictor_prompt |
|
from .solution_generator import InitialSolutionGenerator |
|
|
|
|
|
class TestTimeTaskGenerator:
    """Generate the three AZR reasoning task types from IPO triples.

    Implements requirement 3: use AZR-style prompt templates to produce
    induction, deduction and abduction problems.
    """

    def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None):
        """Keep config/logger and build the helper used for code cleanup.

        Args:
            config: Test-time generation configuration.
            logger: Optional logger; a fresh TestTimeLogger is created when
                none is given.
        """
        self.config = config
        self.logger = logger or TestTimeLogger()

        # Only the code-extraction utilities of InitialSolutionGenerator are
        # needed here, so its first two collaborators are passed as None.
        self.solution_generator = InitialSolutionGenerator(None, None, config, logger)
|
|
|
def generate_tasks(self, ipo_triples: List[Dict[str, Any]], |
|
problem_id: str, round_num: int = 1) -> Dict[str, List[Dict[str, Any]]]: |
|
"""IPO ํธ๋ฆฌํ์์ 3์ข
ํ์คํฌ ์์ฑ (๊ฐ ํธ๋ฆฌํ๋ง๋ค 3๊ฐ์ง ํ์คํฌ ๋ชจ๋ ์์ฑ)""" |
|
|
|
self.logger.log_info(f"๐ฏ Generating tasks for {problem_id} from {len(ipo_triples)} triples") |
|
|
|
|
|
induction_tasks = [] |
|
deduction_tasks = [] |
|
abduction_tasks = [] |
|
|
|
for i, triple in enumerate(ipo_triples): |
|
|
|
induction_task = self._generate_single_induction_task(triple, i, problem_id, round_num) |
|
if induction_task: |
|
induction_tasks.append(induction_task) |
|
|
|
|
|
deduction_task = self._generate_single_deduction_task(triple, i, problem_id, round_num) |
|
if deduction_task: |
|
deduction_tasks.append(deduction_task) |
|
|
|
|
|
abduction_task = self._generate_single_abduction_task(triple, i, problem_id, round_num) |
|
if abduction_task: |
|
abduction_tasks.append(abduction_task) |
|
|
|
all_tasks = { |
|
'induction': induction_tasks, |
|
'deduction': deduction_tasks, |
|
'abduction': abduction_tasks |
|
} |
|
|
|
|
|
task_counts = {k: len(v) for k, v in all_tasks.items()} |
|
total_generated = sum(task_counts.values()) |
|
|
|
self.logger.log_info(f"โ
Generated {len(induction_tasks)} induction, {len(deduction_tasks)} deduction, {len(abduction_tasks)} abduction tasks") |
|
|
|
self.logger.log_task_generation( |
|
problem_id, |
|
induction_tasks, |
|
deduction_tasks, |
|
abduction_tasks |
|
) |
|
|
|
return all_tasks |
|
|
|
def _generate_single_induction_task(self, triple: Dict[str, Any], index: int, problem_id: str, round_num: int) -> Optional[Dict[str, Any]]: |
|
"""๋จ์ผ IPO ํธ๋ฆฌํ์์ induction ํ์คํฌ ์์ฑ""" |
|
|
|
try: |
|
|
|
|
|
input_output_pairs = [(triple['input'], triple['actual_output'])] |
|
|
|
|
|
display_input = triple.get('full_input_str', triple['input']) |
|
|
|
|
|
clean_program = self._extract_clean_function_code(triple['program']) |
|
|
|
|
|
original_problem_id = triple.get('id', '').split('_triple_')[0] |
|
|
|
|
|
if 'HumanEval' in problem_id: |
|
|
|
extracted_message = self._extract_function_description(triple['program']) |
|
if not extracted_message: |
|
extracted_message = "Find the function that produces these outputs from these inputs." |
|
else: |
|
|
|
extracted_message = InitialSolutionGenerator.extract_docstring_from_function(clean_program) |
|
|
|
|
|
|
|
display_pairs = [(display_input, triple['actual_output'])] |
|
azr_prompt = get_code_problem_predictor_prompt( |
|
problem_type='code_f', |
|
snippet=clean_program, |
|
input_output_pairs=display_pairs, |
|
message=extracted_message |
|
) |
|
|
|
|
|
source_program_id = triple.get('source_program_id', f'program_{index//3}') |
|
ipo_index = triple.get('ipo_index', index % 3) |
|
|
|
task = { |
|
'task_id': f'induction_{index}', |
|
'task_type': 'induction', |
|
'triple_id': triple['id'], |
|
'source_program_id': source_program_id, |
|
'ipo_index': ipo_index, |
|
'ipo_triple': { |
|
'input': triple['input'], |
|
'output': triple['actual_output'], |
|
'program': triple['program'] |
|
}, |
|
'prompt': azr_prompt, |
|
'expected_solution': clean_program, |
|
'evaluation_data': { |
|
'input_output_pairs': input_output_pairs, |
|
'original_function': triple['program'] |
|
}, |
|
|
|
|
|
'uid': f"{problem_id}_round_{round_num}_induction_{index}", |
|
'ipo_group_id': f"{problem_id}_program_{source_program_id}_ipo_{ipo_index}", |
|
'original_problem_id': problem_id, |
|
'round': round_num, |
|
'extra_info': {'metric': 'code_f'}, |
|
'basic_accuracy': 0.0, |
|
'ground_truth': clean_program |
|
} |
|
|
|
return task |
|
|
|
except Exception as e: |
|
self.logger.log_error(f"Failed to generate induction task for triple {triple.get('id', 'unknown')}: {e}") |
|
return None |
|
|
|
def _generate_single_deduction_task(self, triple: Dict[str, Any], index: int, problem_id: str, round_num: int) -> Optional[Dict[str, Any]]: |
|
"""๋จ์ผ IPO ํธ๋ฆฌํ์์ deduction ํ์คํฌ ์์ฑ""" |
|
|
|
try: |
|
|
|
original_problem_id = triple.get('id', '').split('_triple_')[0] |
|
|
|
|
|
if 'HumanEval' in original_problem_id: |
|
clean_program = self._remove_doctest_examples(triple['program']) |
|
else: |
|
|
|
clean_program = self._extract_clean_function_code(triple['program']) |
|
|
|
|
|
azr_prompt = get_code_problem_predictor_prompt( |
|
problem_type='code_o', |
|
snippet=clean_program, |
|
input_args=triple['input'] |
|
) |
|
|
|
|
|
source_program_id = triple.get('source_program_id', f'program_{index//3}') |
|
ipo_index = triple.get('ipo_index', index % 3) |
|
|
|
task = { |
|
'task_id': f'deduction_{index}', |
|
'task_type': 'deduction', |
|
'triple_id': triple['id'], |
|
'source_program_id': source_program_id, |
|
'ipo_index': ipo_index, |
|
'ipo_triple': { |
|
'input': triple['input'], |
|
'output': triple['actual_output'], |
|
'program': triple['program'] |
|
}, |
|
'prompt': azr_prompt, |
|
'expected_solution': triple['actual_output'], |
|
'evaluation_data': { |
|
'function_code': clean_program, |
|
'test_input': triple['input'], |
|
'original_function': triple['program'] |
|
}, |
|
|
|
|
|
'uid': f"{problem_id}_round_{round_num}_deduction_{index}", |
|
'ipo_group_id': f"{problem_id}_program_{source_program_id}_ipo_{ipo_index}", |
|
'original_problem_id': problem_id, |
|
'round': round_num, |
|
'extra_info': {'metric': 'code_o'}, |
|
'basic_accuracy': 0.0, |
|
'ground_truth': triple['actual_output'] |
|
} |
|
|
|
return task |
|
|
|
except Exception as e: |
|
self.logger.log_error(f"Failed to generate deduction task for triple {triple.get('id', 'unknown')}: {e}") |
|
return None |
|
|
|
def _generate_single_abduction_task(self, triple: Dict[str, Any], index: int, problem_id: str, round_num: int) -> Optional[Dict[str, Any]]: |
|
"""๋จ์ผ IPO ํธ๋ฆฌํ์์ abduction ํ์คํฌ ์์ฑ""" |
|
|
|
try: |
|
|
|
original_problem_id = triple.get('id', '').split('_triple_')[0] |
|
|
|
|
|
if 'HumanEval' in original_problem_id: |
|
clean_program = self._remove_doctest_examples(triple['program']) |
|
else: |
|
|
|
clean_program = self._extract_clean_function_code(triple['program']) |
|
|
|
|
|
azr_prompt = get_code_problem_predictor_prompt( |
|
problem_type='code_i', |
|
snippet=clean_program, |
|
output=triple['actual_output'] |
|
) |
|
|
|
|
|
source_program_id = triple.get('source_program_id', f'program_{index//3}') |
|
ipo_index = triple.get('ipo_index', index % 3) |
|
|
|
task = { |
|
'task_id': f'abduction_{index}', |
|
'task_type': 'abduction', |
|
'triple_id': triple['id'], |
|
'source_program_id': source_program_id, |
|
'ipo_index': ipo_index, |
|
'ipo_triple': { |
|
'input': triple['input'], |
|
'output': triple['actual_output'], |
|
'program': triple['program'] |
|
}, |
|
'prompt': azr_prompt, |
|
'expected_solution': triple.get('full_input_str', triple['input']), |
|
'evaluation_data': { |
|
'function_code': clean_program, |
|
'expected_output': triple['actual_output'], |
|
'original_function': triple['program'] |
|
}, |
|
|
|
|
|
'uid': f"{problem_id}_round_{round_num}_abduction_{index}", |
|
'ipo_group_id': f"{problem_id}_program_{source_program_id}_ipo_{ipo_index}", |
|
'original_problem_id': problem_id, |
|
'round': round_num, |
|
'extra_info': {'metric': 'code_i'}, |
|
'basic_accuracy': 0.0, |
|
'ground_truth': triple.get('full_input_str', triple['input']) |
|
} |
|
|
|
return task |
|
|
|
except Exception as e: |
|
self.logger.log_error(f"Failed to generate abduction task for triple {triple.get('id', 'unknown')}: {e}") |
|
return None |
|
|
|
def generate_induction_tasks(self, ipo_triples: List[Dict[str, Any]], |
|
count: int) -> List[Dict[str, Any]]: |
|
"""Induction ํ์คํฌ: ์
๋ ฅ-์ถ๋ ฅ ์์์ ํ๋ก๊ทธ๋จ ์ถ๋ก (์ฌ์ฉ์ ์ ์ ์ ์ง)""" |
|
|
|
tasks = [] |
|
selected_triples = random.sample(ipo_triples, min(count, len(ipo_triples))) |
|
|
|
for i, triple in enumerate(selected_triples): |
|
|
|
input_output_pairs = [(triple['input'], triple['actual_output'])] |
|
|
|
|
|
clean_program = self._extract_clean_function_code(triple['program']) |
|
|
|
|
|
extracted_message = InitialSolutionGenerator.extract_docstring_from_function(clean_program) |
|
|
|
|
|
azr_prompt = get_code_problem_predictor_prompt( |
|
problem_type='code_f', |
|
snippet=clean_program, |
|
input_output_pairs=input_output_pairs, |
|
message=extracted_message |
|
) |
|
|
|
task = { |
|
'task_id': f'induction_{i}', |
|
'task_type': 'induction', |
|
'triple_id': triple['id'], |
|
'prompt': azr_prompt, |
|
'expected_solution': clean_program, |
|
'evaluation_data': { |
|
'input_output_pairs': input_output_pairs, |
|
'original_function': triple['program'] |
|
} |
|
} |
|
|
|
tasks.append(task) |
|
|
|
return tasks |
|
|
|
def generate_deduction_tasks(self, ipo_triples: List[Dict[str, Any]], |
|
count: int) -> List[Dict[str, Any]]: |
|
"""Deduction ํ์คํฌ: ํ๋ก๊ทธ๋จ+์
๋ ฅ์์ ์ถ๋ ฅ ์์ธก (์ฌ์ฉ์ ์ ์์ ๋ง๊ฒ ์์ )""" |
|
|
|
tasks = [] |
|
selected_triples = random.sample(ipo_triples, min(count, len(ipo_triples))) |
|
|
|
for i, triple in enumerate(selected_triples): |
|
|
|
clean_program = self._extract_clean_function_code(triple['program']) |
|
|
|
|
|
azr_prompt = get_code_problem_predictor_prompt( |
|
problem_type='code_o', |
|
snippet=clean_program, |
|
input_args=triple['input'] |
|
) |
|
|
|
task = { |
|
'task_id': f'deduction_{i}', |
|
'task_type': 'deduction', |
|
'triple_id': triple['id'], |
|
'prompt': azr_prompt, |
|
'expected_solution': triple['actual_output'], |
|
'evaluation_data': { |
|
'function_code': clean_program, |
|
'test_input': triple['input'] |
|
} |
|
} |
|
|
|
tasks.append(task) |
|
|
|
return tasks |
|
|
|
def generate_abduction_tasks(self, ipo_triples: List[Dict[str, Any]], |
|
count: int) -> List[Dict[str, Any]]: |
|
"""Abduction ํ์คํฌ: ํ๋ก๊ทธ๋จ+์ถ๋ ฅ์์ ์
๋ ฅ ์์ธก (์ฌ์ฉ์ ์ ์์ ๋ง๊ฒ ์์ )""" |
|
|
|
tasks = [] |
|
selected_triples = random.sample(ipo_triples, min(count, len(ipo_triples))) |
|
|
|
for i, triple in enumerate(selected_triples): |
|
|
|
clean_program = self._extract_clean_function_code(triple['program']) |
|
|
|
|
|
azr_prompt = get_code_problem_predictor_prompt( |
|
problem_type='code_i', |
|
snippet=clean_program, |
|
output=triple['actual_output'] |
|
) |
|
|
|
task = { |
|
'task_id': f'abduction_{i}', |
|
'task_type': 'abduction', |
|
'triple_id': triple['id'], |
|
'prompt': azr_prompt, |
|
'expected_solution': triple.get('full_input_str', triple['input']), |
|
'evaluation_data': { |
|
'function_code': clean_program, |
|
'expected_output': triple['actual_output'] |
|
} |
|
} |
|
|
|
tasks.append(task) |
|
|
|
return tasks |
|
|
|
def _extract_clean_function_code(self, program_with_tests: str) -> str: |
|
"""๐ง ์์ : ํ๋ก๊ทธ๋จ์์ test case์ assert๋ฌธ์ ์ ๊ฑฐํ๊ณ ์์ํ ํจ์ ์ฝ๋๋ง ์ถ์ถ""" |
|
|
|
|
|
clean_code = self.solution_generator._extract_function_code(program_with_tests) |
|
|
|
|
|
if "assert" in program_with_tests or "# Test" in program_with_tests: |
|
self.logger.log_info("๐งน Cleaned function code (removed test cases)") |
|
|
|
return clean_code |
|
|
|
def get_task_statistics(self, all_tasks: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]: |
|
"""ํ์คํฌ ์์ฑ ํต๊ณ""" |
|
|
|
stats = { |
|
'total_tasks': sum(len(tasks) for tasks in all_tasks.values()), |
|
'tasks_by_type': {task_type: len(tasks) for task_type, tasks in all_tasks.items()}, |
|
'task_types': list(all_tasks.keys()) |
|
} |
|
|
|
return stats |
|
|
|
def _remove_doctest_examples(self, code: str) -> str: |
|
"""HumanEval ์ฝ๋์์ doctest ์์ ์ ๊ฑฐ""" |
|
import re |
|
|
|
lines = code.split('\n') |
|
result_lines = [] |
|
in_docstring = False |
|
docstring_indent = 0 |
|
skip_next = False |
|
|
|
for line in lines: |
|
stripped = line.strip() |
|
|
|
|
|
if '"""' in line or "'''" in line: |
|
if not in_docstring: |
|
in_docstring = True |
|
docstring_indent = len(line) - len(line.lstrip()) |
|
result_lines.append(line) |
|
else: |
|
in_docstring = False |
|
result_lines.append(line) |
|
continue |
|
|
|
|
|
if in_docstring: |
|
if stripped.startswith('>>>'): |
|
skip_next = True |
|
continue |
|
elif skip_next and stripped and not stripped.startswith('>>>'): |
|
skip_next = False |
|
continue |
|
else: |
|
skip_next = False |
|
|
|
result_lines.append(line) |
|
|
|
return '\n'.join(result_lines) |
|
|
|
def _extract_function_description(self, code: str) -> str: |
|
"""docstring์์ ํจ์ ์ค๋ช
์ถ์ถ (์์ ์ ์ธ)""" |
|
import re |
|
|
|
|
|
patterns = [ |
|
r'"""(.*?)"""', |
|
r"'''(.*?)'''", |
|
] |
|
|
|
for pattern in patterns: |
|
match = re.search(pattern, code, re.DOTALL) |
|
if match: |
|
description = match.group(1).strip() |
|
|
|
result_lines = [] |
|
lines = description.split('\n') |
|
for line in lines: |
|
cleaned_line = line.strip() |
|
|
|
if cleaned_line.startswith('>>>'): |
|
break |
|
|
|
if cleaned_line: |
|
result_lines.append(cleaned_line) |
|
|
|
|
|
if result_lines: |
|
return ' '.join(result_lines) |
|
|
|
return "" |