""" |
|
TTRLVR + AZR ํตํฉ ์์คํ
ํ
์คํธ ์คํฌ๋ฆฝํธ |
|
|
|
์ฃผ์ ์ปดํฌ๋ํธ๋ค์ ๋จ์ ํ
์คํธ ๋ฐ ํตํฉ ํ
์คํธ๋ฅผ ์ํํฉ๋๋ค: |
|
1. Task Generator ํ
์คํธ (AZR ๋ฉํ๋ฐ์ดํฐ ํฌํจ) |
|
2. Complete Pipeline ํ
์คํธ (basic_accuracy ์
๋ฐ์ดํธ) |
|
3. Data Converter ํ
์คํธ (parquet ์ ์ฅ) |
|
4. Iterative Trainer ํ
์คํธ (๋ผ์ด๋ ๊ด๋ฆฌ) |
|
5. ์ ์ฒด ํตํฉ ํ
์คํธ |
|
""" |

import os
import sys
import json
import tempfile
import shutil
import unittest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path

# Make the TTRLVR / AZR packages importable when this script is run directly.
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2')

from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig
from absolute_zero_reasoner.testtime.logger import TestTimeLogger
from absolute_zero_reasoner.testtime.task_generator import TestTimeTaskGenerator
from absolute_zero_reasoner.testtime.complete_pipeline import CompleteTestTimePipeline
from test.utils.iterative_trainer import IterativeTrainer


class TestTTRLVRAZRIntegration(unittest.TestCase):
    """TTRLVR + AZR integrated system tests."""

    def setUp(self):
        """Set up the test configuration and fixtures."""
        self.config = TestTimeConfig()
        self.config.model_name = "Qwen/Qwen2.5-7B"
        self.config.max_new_tokens = 256
        self.config.temperature = 0.05

        self.logger = TestTimeLogger()
        self.test_dir = tempfile.mkdtemp()

        # Sample IPO triples shared by the tests below.
        self.test_ipo_triples = [
            {
                'id': 'HumanEval_0_triple_0',
                'input': '[1, 2, 3]',
                'actual_output': '[2, 4, 6]',
                'program': 'def test_func(lst):\n return [x * 2 for x in lst]',
                'full_input_str': 'test_func([1, 2, 3])',
                'source_program_id': 'program_0',
                'ipo_index': 0
            },
            {
                'id': 'HumanEval_0_triple_1',
                'input': '[4, 5]',
                'actual_output': '[8, 10]',
                'program': 'def test_func(lst):\n return [x * 2 for x in lst]',
                'full_input_str': 'test_func([4, 5])',
                'source_program_id': 'program_0',
                'ipo_index': 1
            }
        ]

    def tearDown(self):
        """Remove temporary test artifacts."""
        if os.path.exists(self.test_dir):
            shutil.rmtree(self.test_dir)

    def test_task_generator_azr_metadata(self):
        """Test AZR metadata generation by the Task Generator."""
        task_generator = TestTimeTaskGenerator(self.config, self.logger)

        problem_id = "HumanEval_0"
        round_num = 3
        tasks = task_generator.generate_tasks(self.test_ipo_triples, problem_id, round_num)

        # All three AZR task types should be present.
        self.assertIn('induction', tasks)
        self.assertIn('deduction', tasks)
        self.assertIn('abduction', tasks)

        for task_type, task_list in tasks.items():
            self.assertGreater(len(task_list), 0, f"{task_type} tasks should be generated")

            for task in task_list:
                # Required metadata fields.
                self.assertIn('uid', task)
                self.assertIn('ipo_group_id', task)
                self.assertIn('original_problem_id', task)
                self.assertIn('round', task)
                self.assertIn('extra_info', task)
                self.assertIn('basic_accuracy', task)
                self.assertIn('ground_truth', task)

                # Metadata values.
                self.assertEqual(task['original_problem_id'], problem_id)
                self.assertEqual(task['round'], round_num)
                self.assertEqual(task['basic_accuracy'], 0.0)
                self.assertIn(problem_id, task['uid'])
                self.assertIn(str(round_num), task['uid'])
                self.assertIn(task_type, task['uid'])

                # Each task type carries its own evaluation metric.
                if task_type == 'induction':
                    self.assertEqual(task['extra_info']['metric'], 'code_f')
                elif task_type == 'deduction':
                    self.assertEqual(task['extra_info']['metric'], 'code_o')
                elif task_type == 'abduction':
                    self.assertEqual(task['extra_info']['metric'], 'code_i')

        print("✅ Task Generator AZR metadata test passed")
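
    # For reference, a sketch of the per-task dict shape the assertions above assume
    # (field names come from the test itself; the concrete values are illustrative only):
    #
    #   {
    #       'uid': 'HumanEval_0_round_3_induction_0',
    #       'ipo_group_id': 'HumanEval_0_program_0_ipo_0',
    #       'original_problem_id': 'HumanEval_0',
    #       'round': 3,
    #       'basic_accuracy': 0.0,
    #       'ground_truth': '...',               # e.g. the reference program for induction tasks
    #       'extra_info': {'metric': 'code_f'},  # 'code_o' for deduction, 'code_i' for abduction
    #   }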

    def test_data_converter_parquet_format(self):
        """Test the parquet output format of the data converter."""
        mock_tasks = {
            'induction': [
                {
                    'task_id': 'induction_0',
                    'task_type': 'induction',
                    'prompt': 'Test prompt',
                    'uid': 'HumanEval_0_round_1_induction_0',
                    'ipo_group_id': 'HumanEval_0_program_0_ipo_0',
                    'source_program_id': 'program_0',
                    'ipo_index': 0,
                    'ipo_triple': {
                        'input': '[1, 2, 3]',
                        'output': '[2, 4, 6]',
                        'program': 'def test_func(lst):\n return [x * 2 for x in lst]'
                    },
                    'ground_truth': 'def test_func(lst):\n return [x * 2 for x in lst]',
                    'extra_info': {'metric': 'code_f'},
                    'basic_accuracy': 1.0,
                    'original_problem_id': 'HumanEval_0',
                    'round': 1
                }
            ]
        }

        with patch('absolute_zero_reasoner.testtime.complete_pipeline.CompleteTestTimePipeline') as mock_pipeline:
            pipeline = CompleteTestTimePipeline(self.config, self.logger)

            output_dir = self.test_dir
            problem_id = "HumanEval_0"
            round_num = 1

            saved_files = pipeline._save_azr_training_data(mock_tasks, problem_id, round_num, output_dir)

            # A parquet file should be written for the induction tasks.
            self.assertIn('induction', saved_files)
            self.assertTrue(os.path.exists(saved_files['induction']))

            import pandas as pd
            df = pd.read_parquet(saved_files['induction'])

            # Required columns.
            self.assertEqual(len(df), 1)
            self.assertIn('prompt', df.columns)
            self.assertIn('uid', df.columns)
            self.assertIn('ipo_group_id', df.columns)
            self.assertIn('ground_truth', df.columns)
            self.assertIn('basic_accuracy', df.columns)

            # Prompts are stored in chat format (list of role/content messages).
            prompt_data = df.iloc[0]['prompt']
            self.assertIsInstance(prompt_data, list)
            self.assertEqual(prompt_data[0]['role'], 'user')
            self.assertIn('content', prompt_data[0])

        print("✅ Data converter parquet format test passed")

    def test_complete_pipeline_basic_accuracy_update(self):
        """Test the Complete Pipeline's basic_accuracy update."""
        with patch.multiple(
            'absolute_zero_reasoner.testtime.complete_pipeline.CompleteTestTimePipeline',
            _generate_task_response=Mock(return_value="test response"),
            _extract_answer_by_task_type=Mock(return_value="test answer"),
            _calculate_task_accuracy=Mock(return_value=0.8)
        ):
            pipeline = CompleteTestTimePipeline(self.config, self.logger)

            mock_tasks = {
                'induction': [
                    {
                        'task_id': 'induction_0',
                        'prompt': 'test prompt',
                        'expected_solution': 'test solution',
                        'evaluation_data': {'test': 'data'},
                        'basic_accuracy': 0.0
                    }
                ]
            }

            evaluations = pipeline._evaluate_tasks_with_llm(mock_tasks)

            # The task itself should be updated in place with the mocked accuracy.
            updated_task = mock_tasks['induction'][0]
            self.assertEqual(updated_task['basic_accuracy'], 0.8)

            # The evaluation result should carry the same value.
            self.assertIn('induction', evaluations)
            eval_result = evaluations['induction'][0]
            self.assertIn('basic_accuracy', eval_result)
            self.assertEqual(eval_result['basic_accuracy'], 0.8)

        print("✅ Complete pipeline basic_accuracy update test passed")

    def test_iterative_trainer_round_management(self):
        """Test the Iterative Trainer's round management."""
        benchmark_config = BenchmarkConfig(
            name='test_benchmark',
            problems_path='/test/path',
            max_problems=None
        )

        problem_ids = ['TestProblem_1', 'TestProblem_2']

        with patch.object(IterativeTrainer, '_update_pipeline_model'):
            with patch.object(IterativeTrainer, '_train_azr_with_round_data') as mock_train:
                with patch.object(IterativeTrainer, '_save_checkpoint'):
                    with patch('absolute_zero_reasoner.testtime.complete_pipeline.CompleteTestTimePipeline') as mock_pipeline_class:

                        mock_pipeline = Mock()
                        mock_pipeline.run_complete_pipeline.return_value = {
                            'success': True,
                            'azr_training_data': {'induction': '/test/path/induction.parquet'},
                            'steps': {
                                'azr_data_saving': {'total_tasks': 10}
                            }
                        }
                        mock_pipeline_class.return_value = mock_pipeline

                        mock_train.return_value = "/data/RLVR/checkpoints/ttrlvr_azr/round_1"

                        trainer = IterativeTrainer(self.config, self.logger)
                        trainer.checkpoint_dir = self.test_dir

                        round_result = trainer._run_single_round(benchmark_config, problem_ids, 1)

                        # Both problems should be processed successfully in the round.
                        self.assertTrue(round_result['success'])
                        self.assertEqual(len(round_result['problems']), 2)
                        self.assertGreater(len(round_result['training_data_files']), 0)

                        stats = round_result['stats']
                        self.assertEqual(stats['total_problems'], 2)
                        self.assertEqual(stats['successful_problems'], 2)
                        self.assertEqual(stats['failed_problems'], 0)

        print("✅ Iterative trainer round management test passed")

    def test_data_combination_and_sorting(self):
        """Test combining and sorting of per-round training data."""
        trainer = IterativeTrainer(self.config, self.logger)

        training_data_files = [
            {
                'problem_id': 'TestProblem_1',
                'files': {
                    'induction': os.path.join(self.test_dir, 'test_induction_1.parquet')
                }
            },
            {
                'problem_id': 'TestProblem_2',
                'files': {
                    'induction': os.path.join(self.test_dir, 'test_induction_2.parquet')
                }
            }
        ]

        import pandas as pd

        # Write two small per-problem parquet files to combine.
        data1 = pd.DataFrame([
            {
                'uid': 'TestProblem_1_round_1_induction_0',
                'ipo_group_id': 'TestProblem_1_program_1_ipo_2',
                'basic_accuracy': 0.8
            }
        ])
        data1.to_parquet(training_data_files[0]['files']['induction'], index=False)

        data2 = pd.DataFrame([
            {
                'uid': 'TestProblem_2_round_1_induction_0',
                'ipo_group_id': 'TestProblem_2_program_0_ipo_1',
                'basic_accuracy': 0.6
            }
        ])
        data2.to_parquet(training_data_files[1]['files']['induction'], index=False)

        combined_path = trainer._combine_round_data(training_data_files, 1)

        self.assertIsNotNone(combined_path)
        self.assertTrue(os.path.exists(combined_path))

        combined_file = os.path.join(combined_path, 'induction.parquet')
        self.assertTrue(os.path.exists(combined_file))

        combined_df = pd.read_parquet(combined_file)
        self.assertEqual(len(combined_df), 2)

        # Rows should be sorted by ipo_group_id.
        ipo_groups = combined_df['ipo_group_id'].tolist()
        self.assertEqual(ipo_groups, sorted(ipo_groups))

        print("✅ Data combination and sorting test passed")


class TestPerformanceAndMemory(unittest.TestCase):
    """Performance and memory tests."""

    def setUp(self):
        self.config = TestTimeConfig()
        self.logger = TestTimeLogger()

    def test_memory_cleanup_between_rounds(self):
        """Test memory cleanup between rounds."""
        with patch('absolute_zero_reasoner.testtime.complete_pipeline.CompleteTestTimePipeline') as mock_pipeline_class:
            mock_pipeline = Mock()
            mock_pipeline.model = Mock()
            mock_pipeline.tokenizer = Mock()
            mock_pipeline_class.return_value = mock_pipeline

            trainer = IterativeTrainer(self.config, self.logger)

            trainer._update_pipeline_model("/new/model/path")

            # The old model and tokenizer should have been released.
            self.assertIsNone(trainer.complete_pipeline.model)
            self.assertIsNone(trainer.complete_pipeline.tokenizer)

        print("✅ Memory cleanup between rounds test passed")

    def test_checkpoint_size_and_structure(self):
        """Test checkpoint size and directory structure."""
        trainer = IterativeTrainer(self.config, self.logger)
        test_dir = tempfile.mkdtemp()
        trainer.checkpoint_dir = test_dir

        try:
            training_results = {
                'total_rounds': 30,
                'rounds': {
                    1: {'success': True, 'stats': {'total_tasks': 100}},
                    2: {'success': True, 'stats': {'total_tasks': 95}},
                    3: {'success': False, 'error': 'Test error'}
                }
            }

            trainer._save_checkpoint(3, "/test/model/path", training_results)

            # The checkpoint directory and its files should exist.
            checkpoint_path = os.path.join(test_dir, "checkpoint_round_3")
            self.assertTrue(os.path.exists(checkpoint_path))

            checkpoint_file = os.path.join(checkpoint_path, "checkpoint.json")
            summary_file = os.path.join(checkpoint_path, "summary.txt")

            self.assertTrue(os.path.exists(checkpoint_file))
            self.assertTrue(os.path.exists(summary_file))

            # The JSON checkpoint should record the round, model path, and results.
            with open(checkpoint_file, 'r') as f:
                checkpoint_data = json.load(f)

            self.assertEqual(checkpoint_data['round_num'], 3)
            self.assertEqual(checkpoint_data['model_path'], "/test/model/path")
            self.assertIn('training_results', checkpoint_data)

            # The metadata checkpoint should stay small (< 1 MB).
            checkpoint_size = os.path.getsize(checkpoint_file)
            self.assertLess(checkpoint_size, 1024 * 1024)

        finally:
            shutil.rmtree(test_dir)

        print("✅ Checkpoint size and structure test passed")
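
    # For reference, the checkpoint layout exercised above (relative to
    # trainer.checkpoint_dir; inferred from the assertions, not from the trainer source):
    #
    #   checkpoint_round_<N>/
    #       checkpoint.json   # round_num, model_path, training_results
    #       summary.txt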


def run_integration_test():
    """Run the full integration test suite and print a summary."""
    print("🧪 Starting TTRLVR + AZR integrated system tests")
    print("=" * 60)

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()

    suite.addTests(loader.loadTestsFromTestCase(TestTTRLVRAZRIntegration))
    suite.addTests(loader.loadTestsFromTestCase(TestPerformanceAndMemory))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    print("\n" + "=" * 60)
    print("📊 Test result summary:")
    print(f" - Tests run: {result.testsRun}")
    print(f" - Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
    print(f" - Failures: {len(result.failures)}")
    print(f" - Errors: {len(result.errors)}")

    if result.failures:
        print("\n❌ Failed tests:")
        for test, traceback in result.failures:
            print(f" - {test}: {traceback}")

    if result.errors:
        print("\n💥 Tests with errors:")
        for test, traceback in result.errors:
            print(f" - {test}: {traceback}")

    success = len(result.failures) == 0 and len(result.errors) == 0

    if success:
        print("\n🎉 All tests passed!")
    else:
        print("\n⚠️ Some tests failed or raised errors")

    return success


if __name__ == '__main__':
    success = run_integration_test()
    sys.exit(0 if success else 1)