"""
TestTime Task Generator
AZR ์ถ”๋ก ์šฉ ํ”„๋กฌํ”„ํŠธ ๊ธฐ๋ฐ˜ Induction/Deduction/Abduction ํƒœ์Šคํฌ ์ƒ์„ฑ
์š”๊ตฌ์‚ฌํ•ญ 3: "AZR์ฒ˜๋Ÿผ ํ…œํ”Œ๋ฆฟ์„ ํ™œ์šฉํ•˜์—ฌ induction, deduction, abduction ๋ฌธ์ œ๋ฅผ ์ƒ์„ฑ"
"""
from typing import Dict, List, Any, Optional, Tuple
import random
from .config import TestTimeConfig
from .logger import TestTimeLogger
# Use the AZR reasoning prompts directly
from ..data_construction.prompts import get_code_problem_predictor_prompt
from .solution_generator import InitialSolutionGenerator
class TestTimeTaskGenerator:
"""IPO ํŠธ๋ฆฌํ”Œ์—์„œ 3์ข… ํƒœ์Šคํฌ ์ƒ์„ฑ"""
def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None):
self.config = config
self.logger = logger or TestTimeLogger()
        # The AZR reasoning prompt (get_code_problem_predictor_prompt) is used directly.
        # Solution generator instance, used here only to clean function code.
self.solution_generator = InitialSolutionGenerator(None, None, config, logger)
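    # Expected shape of one IPO triple, inferred from the accesses below (an
    # illustrative sketch, not a schema enforced by this class; only 'id',
    # 'input', 'actual_output', and 'program' are strictly required):
    #     {
    #         'id': 'HumanEval_0_triple_0',            # '<problem>_triple_<n>'
    #         'input': '([1.0, 2.0], 0.5)',            # actual call arguments
    #         'actual_output': 'False',                # observed output
    #         'program': 'def f(...): ...',            # program, may include tests/doctests
    #         'full_input_str': 'f([1.0, 2.0], 0.5)',  # optional display/call form
    #         'source_program_id': 'program_0',        # optional
    #         'ipo_index': 0,                          # optional
    #     }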
def generate_tasks(self, ipo_triples: List[Dict[str, Any]],
problem_id: str, round_num: int = 1) -> Dict[str, List[Dict[str, Any]]]:
"""IPO ํŠธ๋ฆฌํ”Œ์—์„œ 3์ข… ํƒœ์Šคํฌ ์ƒ์„ฑ (๊ฐ ํŠธ๋ฆฌํ”Œ๋งˆ๋‹ค 3๊ฐ€์ง€ ํƒœ์Šคํฌ ๋ชจ๋‘ ์ƒ์„ฑ)"""
self.logger.log_info(f"๐ŸŽฏ Generating tasks for {problem_id} from {len(ipo_triples)} triples")
        # Distribution logic removed: every IPO triple yields all three task types.
induction_tasks = []
deduction_tasks = []
abduction_tasks = []
for i, triple in enumerate(ipo_triples):
# ๊ฐ ํŠธ๋ฆฌํ”Œ์—์„œ induction ํƒœ์Šคํฌ ์ƒ์„ฑ
induction_task = self._generate_single_induction_task(triple, i, problem_id, round_num)
if induction_task:
induction_tasks.append(induction_task)
# ๊ฐ ํŠธ๋ฆฌํ”Œ์—์„œ deduction ํƒœ์Šคํฌ ์ƒ์„ฑ
deduction_task = self._generate_single_deduction_task(triple, i, problem_id, round_num)
if deduction_task:
deduction_tasks.append(deduction_task)
# ๊ฐ ํŠธ๋ฆฌํ”Œ์—์„œ abduction ํƒœ์Šคํฌ ์ƒ์„ฑ
abduction_task = self._generate_single_abduction_task(triple, i, problem_id, round_num)
if abduction_task:
abduction_tasks.append(abduction_task)
all_tasks = {
'induction': induction_tasks,
'deduction': deduction_tasks,
'abduction': abduction_tasks
}
        # Logging
        task_counts = {k: len(v) for k, v in all_tasks.items()}
        total_generated = sum(task_counts.values())
        self.logger.log_info(
            f"โœ… Generated {total_generated} tasks: "
            f"{task_counts['induction']} induction, {task_counts['deduction']} deduction, "
            f"{task_counts['abduction']} abduction"
        )
self.logger.log_task_generation(
problem_id,
induction_tasks,
deduction_tasks,
abduction_tasks
)
return all_tasks
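    # Shape of the value returned by generate_tasks, with illustrative entries
    # (a sketch; each list holds the task dicts built below, one per triple):
    #     {
    #         'induction': [{'task_id': 'induction_0', 'task_type': 'induction', ...}],
    #         'deduction': [{'task_id': 'deduction_0', 'task_type': 'deduction', ...}],
    #         'abduction': [{'task_id': 'abduction_0', 'task_type': 'abduction', ...}],
    #     }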
def _generate_single_induction_task(self, triple: Dict[str, Any], index: int, problem_id: str, round_num: int) -> Optional[Dict[str, Any]]:
"""๋‹จ์ผ IPO ํŠธ๋ฆฌํ”Œ์—์„œ induction ํƒœ์Šคํฌ ์ƒ์„ฑ"""
try:
            # Prepare the input-output pair.
            # Evaluation uses the actual call arguments (triple['input']).
input_output_pairs = [(triple['input'], triple['actual_output'])]
            # For display, prefer full_input_str when available.
display_input = triple.get('full_input_str', triple['input'])
            # Extract only the clean function code (test cases removed).
clean_program = self._extract_clean_function_code(triple['program'])
# ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ๋ฐ›์€ problem_id ์‚ฌ์šฉ (AZR ํ†ตํ•ฉ์šฉ)
original_problem_id = triple.get('id', '').split('_triple_')[0] # ์›๋ณธ ์ถ”์ถœ ๋กœ์ง ๋ณด์กด
            # Special handling for HumanEval problems
if 'HumanEval' in problem_id:
                # Extract the function description from the original program (its doctest examples are still present).
extracted_message = self._extract_function_description(triple['program'])
if not extracted_message:
extracted_message = "Find the function that produces these outputs from these inputs."
else:
                # MBPP keeps the existing approach.
extracted_message = InitialSolutionGenerator.extract_docstring_from_function(clean_program)
            # Task definition: input_output_pairs + message -> program.
            # The prompt uses the display form of the input.
display_pairs = [(display_input, triple['actual_output'])]
azr_prompt = get_code_problem_predictor_prompt(
problem_type='code_f',
                snippet=clean_program,  # use the cleaned code
input_output_pairs=display_pairs,
message=extracted_message
)
            # Build AZR metadata
source_program_id = triple.get('source_program_id', f'program_{index//3}')
ipo_index = triple.get('ipo_index', index % 3)
task = {
'task_id': f'induction_{index}',
'task_type': 'induction',
'triple_id': triple['id'],
                'source_program_id': source_program_id,
                'ipo_index': ipo_index,
                'ipo_triple': {
'input': triple['input'],
'output': triple['actual_output'],
'program': triple['program']
},
'prompt': azr_prompt,
                'expected_solution': clean_program,  # the cleaned code is the reference solution
'evaluation_data': {
                    'input_output_pairs': input_output_pairs,  # evaluation uses the actual call arguments
'original_function': triple['program']
},
                # AZR training metadata
'uid': f"{problem_id}_round_{round_num}_induction_{index}",
'ipo_group_id': f"{problem_id}_program_{source_program_id}_ipo_{ipo_index}",
'original_problem_id': problem_id,
'round': round_num,
                'extra_info': {'metric': 'code_f'},  # induction tasks use the code_f metric
                'basic_accuracy': 0.0,  # initial value, updated during evaluation
                'ground_truth': clean_program  # used by the AZR parquet format
}
return task
except Exception as e:
self.logger.log_error(f"Failed to generate induction task for triple {triple.get('id', 'unknown')}: {e}")
return None
def _generate_single_deduction_task(self, triple: Dict[str, Any], index: int, problem_id: str, round_num: int) -> Optional[Dict[str, Any]]:
"""๋‹จ์ผ IPO ํŠธ๋ฆฌํ”Œ์—์„œ deduction ํƒœ์Šคํฌ ์ƒ์„ฑ"""
try:
# ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ๋ฐ›์€ problem_id ์‚ฌ์šฉ (AZR ํ†ตํ•ฉ์šฉ)
original_problem_id = triple.get('id', '').split('_triple_')[0] # ์›๋ณธ ์ถ”์ถœ ๋กœ์ง ๋ณด์กด
            # For HumanEval, strip the doctest examples.
if 'HumanEval' in original_problem_id:
clean_program = self._remove_doctest_examples(triple['program'])
else:
                # MBPP keeps the existing approach.
clean_program = self._extract_clean_function_code(triple['program'])
            # Task definition: program + input -> output.
azr_prompt = get_code_problem_predictor_prompt(
                problem_type='code_o',  # program + input -> output
                snippet=clean_program,  # use the cleaned code
input_args=triple['input']
)
            # Build AZR metadata
source_program_id = triple.get('source_program_id', f'program_{index//3}')
ipo_index = triple.get('ipo_index', index % 3)
task = {
'task_id': f'deduction_{index}',
'task_type': 'deduction',
'triple_id': triple['id'],
                'source_program_id': source_program_id,
                'ipo_index': ipo_index,
                'ipo_triple': {
'input': triple['input'],
'output': triple['actual_output'],
'program': triple['program']
},
'prompt': azr_prompt,
                'expected_solution': triple['actual_output'],  # unified under the expected_solution key
'evaluation_data': {
                    'function_code': clean_program,  # cleaned code (matches complete_pipeline)
                    'test_input': triple['input'],  # matches complete_pipeline
'original_function': triple['program']
},
                # AZR training metadata
'uid': f"{problem_id}_round_{round_num}_deduction_{index}",
'ipo_group_id': f"{problem_id}_program_{source_program_id}_ipo_{ipo_index}",
'original_problem_id': problem_id,
'round': round_num,
                'extra_info': {'metric': 'code_o'},  # deduction tasks use the code_o metric
                'basic_accuracy': 0.0,  # initial value, updated during evaluation
                'ground_truth': triple['actual_output']  # used by the AZR parquet format
}
return task
except Exception as e:
self.logger.log_error(f"Failed to generate deduction task for triple {triple.get('id', 'unknown')}: {e}")
return None
def _generate_single_abduction_task(self, triple: Dict[str, Any], index: int, problem_id: str, round_num: int) -> Optional[Dict[str, Any]]:
"""๋‹จ์ผ IPO ํŠธ๋ฆฌํ”Œ์—์„œ abduction ํƒœ์Šคํฌ ์ƒ์„ฑ"""
try:
# ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ๋ฐ›์€ problem_id ์‚ฌ์šฉ (AZR ํ†ตํ•ฉ์šฉ)
original_problem_id = triple.get('id', '').split('_triple_')[0] # ์›๋ณธ ์ถ”์ถœ ๋กœ์ง ๋ณด์กด
            # For HumanEval, strip the doctest examples.
if 'HumanEval' in original_problem_id:
clean_program = self._remove_doctest_examples(triple['program'])
else:
                # MBPP keeps the existing approach.
clean_program = self._extract_clean_function_code(triple['program'])
            # Task definition: program + output -> input.
azr_prompt = get_code_problem_predictor_prompt(
                problem_type='code_i',  # program + output -> input
                snippet=clean_program,  # use the cleaned code
                output=triple['actual_output']  # pass the observed output
)
            # Build AZR metadata
source_program_id = triple.get('source_program_id', f'program_{index//3}')
ipo_index = triple.get('ipo_index', index % 3)
task = {
'task_id': f'abduction_{index}',
'task_type': 'abduction',
'triple_id': triple['id'],
                'source_program_id': source_program_id,
                'ipo_index': ipo_index,
                'ipo_triple': {
'input': triple['input'],
'output': triple['actual_output'],
'program': triple['program']
},
'prompt': azr_prompt,
                'expected_solution': triple.get('full_input_str', triple['input']),  # full function-call string when available
'evaluation_data': {
                    'function_code': clean_program,  # cleaned code (matches complete_pipeline)
                    'expected_output': triple['actual_output'],  # matches complete_pipeline
'original_function': triple['program']
},
                # AZR training metadata
'uid': f"{problem_id}_round_{round_num}_abduction_{index}",
'ipo_group_id': f"{problem_id}_program_{source_program_id}_ipo_{ipo_index}",
'original_problem_id': problem_id,
'round': round_num,
                'extra_info': {'metric': 'code_i'},  # abduction tasks use the code_i metric
                'basic_accuracy': 0.0,  # initial value, updated during evaluation
                'ground_truth': triple.get('full_input_str', triple['input'])  # used by the AZR parquet format
}
return task
except Exception as e:
self.logger.log_error(f"Failed to generate abduction task for triple {triple.get('id', 'unknown')}: {e}")
return None
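    # Mapping used by the three builders above (for reference; the authoritative
    # values are the literals passed to get_code_problem_predictor_prompt):
    #     induction -> 'code_f'  (input/output pairs + message -> program)
    #     deduction -> 'code_o'  (program + input  -> output)
    #     abduction -> 'code_i'  (program + output -> input)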
def generate_induction_tasks(self, ipo_triples: List[Dict[str, Any]],
count: int) -> List[Dict[str, Any]]:
"""Induction ํƒœ์Šคํฌ: ์ž…๋ ฅ-์ถœ๋ ฅ ์Œ์—์„œ ํ”„๋กœ๊ทธ๋žจ ์ถ”๋ก  (์‚ฌ์šฉ์ž ์ •์˜ ์œ ์ง€)"""
tasks = []
selected_triples = random.sample(ipo_triples, min(count, len(ipo_triples)))
for i, triple in enumerate(selected_triples):
            # Prepare the input-output pair
input_output_pairs = [(triple['input'], triple['actual_output'])]
            # Extract only the clean function code (test cases removed).
clean_program = self._extract_clean_function_code(triple['program'])
            # Use the docstring extracted from the LLM-generated function as the message.
extracted_message = InitialSolutionGenerator.extract_docstring_from_function(clean_program)
            # Task definition: input_output_pairs + message -> program.
azr_prompt = get_code_problem_predictor_prompt(
problem_type='code_f',
                snippet=clean_program,  # use the cleaned code
input_output_pairs=input_output_pairs,
message=extracted_message
)
task = {
'task_id': f'induction_{i}',
'task_type': 'induction',
'triple_id': triple['id'],
'prompt': azr_prompt,
                'expected_solution': clean_program,  # the cleaned code is the reference solution
'evaluation_data': {
'input_output_pairs': input_output_pairs,
'original_function': triple['program']
}
}
tasks.append(task)
return tasks
def generate_deduction_tasks(self, ipo_triples: List[Dict[str, Any]],
count: int) -> List[Dict[str, Any]]:
"""Deduction ํƒœ์Šคํฌ: ํ”„๋กœ๊ทธ๋žจ+์ž…๋ ฅ์—์„œ ์ถœ๋ ฅ ์˜ˆ์ธก (์‚ฌ์šฉ์ž ์ •์˜์— ๋งž๊ฒŒ ์ˆ˜์ •)"""
tasks = []
selected_triples = random.sample(ipo_triples, min(count, len(ipo_triples)))
for i, triple in enumerate(selected_triples):
            # Extract only the clean function code (test cases removed).
clean_program = self._extract_clean_function_code(triple['program'])
            # Task definition: program + input -> output.
azr_prompt = get_code_problem_predictor_prompt(
                problem_type='code_o',  # program + input -> output
                snippet=clean_program,  # use the cleaned code
input_args=triple['input']
)
task = {
'task_id': f'deduction_{i}',
'task_type': 'deduction',
'triple_id': triple['id'],
'prompt': azr_prompt,
'expected_solution': triple['actual_output'],
'evaluation_data': {
                    'function_code': clean_program,  # use the cleaned code
'test_input': triple['input']
}
}
tasks.append(task)
return tasks
def generate_abduction_tasks(self, ipo_triples: List[Dict[str, Any]],
count: int) -> List[Dict[str, Any]]:
"""Abduction ํƒœ์Šคํฌ: ํ”„๋กœ๊ทธ๋žจ+์ถœ๋ ฅ์—์„œ ์ž…๋ ฅ ์˜ˆ์ธก (์‚ฌ์šฉ์ž ์ •์˜์— ๋งž๊ฒŒ ์ˆ˜์ •)"""
tasks = []
selected_triples = random.sample(ipo_triples, min(count, len(ipo_triples)))
for i, triple in enumerate(selected_triples):
            # Extract only the clean function code (test cases removed).
clean_program = self._extract_clean_function_code(triple['program'])
            # Task definition: program + output -> input.
azr_prompt = get_code_problem_predictor_prompt(
                problem_type='code_i',  # program + output -> input
                snippet=clean_program,  # use the cleaned code
output=triple['actual_output']
)
task = {
'task_id': f'abduction_{i}',
'task_type': 'abduction',
'triple_id': triple['id'],
'prompt': azr_prompt,
                'expected_solution': triple.get('full_input_str', triple['input']),  # full function-call string when available
'evaluation_data': {
                    'function_code': clean_program,  # use the cleaned code
'expected_output': triple['actual_output']
}
}
tasks.append(task)
return tasks
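    # The three generate_*_tasks methods above are the count-based variants: they
    # sample up to `count` triples uniformly without replacement, e.g.
    #     tasks = generator.generate_deduction_tasks(ipo_triples, count=3)
    # whereas generate_tasks builds all three task types for every triple.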
def _extract_clean_function_code(self, program_with_tests: str) -> str:
"""๐Ÿ”ง ์ˆ˜์ •: ํ”„๋กœ๊ทธ๋žจ์—์„œ test case์™€ assert๋ฌธ์„ ์ œ๊ฑฐํ•˜๊ณ  ์ˆœ์ˆ˜ํ•œ ํ•จ์ˆ˜ ์ฝ”๋“œ๋งŒ ์ถ”์ถœ"""
# solution_generator์˜ _extract_function_code ๋ฉ”์„œ๋“œ ์‚ฌ์šฉ
clean_code = self.solution_generator._extract_function_code(program_with_tests)
        # Debug logging
if "assert" in program_with_tests or "# Test" in program_with_tests:
self.logger.log_info("๐Ÿงน Cleaned function code (removed test cases)")
return clean_code
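    # Illustrative before/after for _extract_clean_function_code (assuming
    # _extract_function_code keeps only the def block, as its use here implies):
    #     before: "def add(a, b):\n    return a + b\n\n# Test\nassert add(1, 2) == 3"
    #     after:  "def add(a, b):\n    return a + b"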
def get_task_statistics(self, all_tasks: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
"""ํƒœ์Šคํฌ ์ƒ์„ฑ ํ†ต๊ณ„"""
stats = {
'total_tasks': sum(len(tasks) for tasks in all_tasks.values()),
'tasks_by_type': {task_type: len(tasks) for task_type, tasks in all_tasks.items()},
'task_types': list(all_tasks.keys())
}
return stats
    def _remove_doctest_examples(self, code: str) -> str:
        """Remove doctest examples from HumanEval code."""
        lines = code.split('\n')
        result_lines = []
        in_docstring = False
        skip_next = False
for line in lines:
stripped = line.strip()
            # Detect docstring start/end
if '"""' in line or "'''" in line:
if not in_docstring:
in_docstring = True
result_lines.append(line)
else:
in_docstring = False
result_lines.append(line)
continue
            # Skip doctest example lines
if in_docstring:
if stripped.startswith('>>>'):
                    skip_next = True  # also skip the next line (the expected output)
continue
elif skip_next and stripped and not stripped.startswith('>>>'):
skip_next = False
continue
else:
skip_next = False
result_lines.append(line)
return '\n'.join(result_lines)
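    # Illustrative effect of _remove_doctest_examples: in a HumanEval-style
    # snippet such as the one below, the '>>> inc(1)' line and its expected
    # output line '2' are dropped; everything else is kept verbatim.
    #     def inc(x):
    #         """Increment x.
    #         >>> inc(1)
    #         2
    #         """
    #         return x + 1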
def _extract_function_description(self, code: str) -> str:
"""docstring์—์„œ ํ•จ์ˆ˜ ์„ค๋ช… ์ถ”์ถœ (์˜ˆ์‹œ ์ œ์™ธ)"""
import re
        # Match both docstring delimiter styles
patterns = [
r'"""(.*?)"""', # triple double quotes
r"'''(.*?)'''", # triple single quotes
]
for pattern in patterns:
match = re.search(pattern, code, re.DOTALL)
if match:
description = match.group(1).strip()
                # Collect every description line up to the first example
result_lines = []
lines = description.split('\n')
for line in lines:
cleaned_line = line.strip()
                    # Stop at the first >>> example
if cleaned_line.startswith('>>>'):
break
                    # Keep non-empty, non-example lines
if cleaned_line:
result_lines.append(cleaned_line)
                # Join all description lines with spaces
if result_lines:
return ' '.join(result_lines)
return ""