#!/usr/bin/env python3
"""
TTRLVR + AZR integrated system test script
Runs unit tests and integration tests for the main components:
1. Task Generator test (including AZR metadata)
2. Complete Pipeline test (basic_accuracy updates)
3. Data Converter test (parquet saving)
4. Iterative Trainer test (round management)
5. Full integration test
"""
import os
import sys
import json
import tempfile
import shutil
import unittest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
# Path setup
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2')
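# NOTE: hardcoded checkout location; adjust if the repository lives elsewhere.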
# TTRLVR module imports
from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig
from absolute_zero_reasoner.testtime.logger import TestTimeLogger
from absolute_zero_reasoner.testtime.task_generator import TestTimeTaskGenerator
from absolute_zero_reasoner.testtime.complete_pipeline import CompleteTestTimePipeline
from test.utils.iterative_trainer import IterativeTrainer
class TestTTRLVRAZRIntegration(unittest.TestCase):
"""TTRLVR + AZR ํ†ตํ•ฉ ์‹œ์Šคํ…œ ํ…Œ์ŠคํŠธ"""
def setUp(self):
"""ํ…Œ์ŠคํŠธ ์„ค์ •"""
self.config = TestTimeConfig()
self.config.model_name = "Qwen/Qwen2.5-7B"
self.config.max_new_tokens = 256
self.config.temperature = 0.05
self.logger = TestTimeLogger()
self.test_dir = tempfile.mkdtemp()
        # Test IPO triple data
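        # Each triple pairs a sample input with the program that was run on it and the
        # observed output (the 'input', 'program', and 'actual_output' fields below).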
self.test_ipo_triples = [
{
'id': 'HumanEval_0_triple_0',
'input': '[1, 2, 3]',
'actual_output': '[2, 4, 6]',
'program': 'def test_func(lst):\n return [x * 2 for x in lst]',
'full_input_str': 'test_func([1, 2, 3])',
'source_program_id': 'program_0',
'ipo_index': 0
},
{
'id': 'HumanEval_0_triple_1',
'input': '[4, 5]',
'actual_output': '[8, 10]',
'program': 'def test_func(lst):\n return [x * 2 for x in lst]',
'full_input_str': 'test_func([4, 5])',
'source_program_id': 'program_0',
'ipo_index': 1
}
]
def tearDown(self):
"""ํ…Œ์ŠคํŠธ ์ •๋ฆฌ"""
if os.path.exists(self.test_dir):
shutil.rmtree(self.test_dir)
def test_task_generator_azr_metadata(self):
"""Task Generator์˜ AZR ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ์ƒ์„ฑ ํ…Œ์ŠคํŠธ"""
task_generator = TestTimeTaskGenerator(self.config, self.logger)
        # Generate tasks (including round_num)
problem_id = "HumanEval_0"
round_num = 3
tasks = task_generator.generate_tasks(self.test_ipo_triples, problem_id, round_num)
        # Verify the basic structure
self.assertIn('induction', tasks)
self.assertIn('deduction', tasks)
self.assertIn('abduction', tasks)
        # Verify each task type
for task_type, task_list in tasks.items():
self.assertGreater(len(task_list), 0, f"{task_type} tasks should be generated")
for task in task_list:
                # Verify the AZR metadata
self.assertIn('uid', task)
self.assertIn('ipo_group_id', task)
self.assertIn('original_problem_id', task)
self.assertIn('round', task)
self.assertIn('extra_info', task)
self.assertIn('basic_accuracy', task)
self.assertIn('ground_truth', task)
                # Verify the values
self.assertEqual(task['original_problem_id'], problem_id)
self.assertEqual(task['round'], round_num)
                self.assertEqual(task['basic_accuracy'], 0.0)  # initial value
self.assertIn(problem_id, task['uid'])
self.assertIn(str(round_num), task['uid'])
self.assertIn(task_type, task['uid'])
                # Verify the metric for each task type
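                # (Assumed AZR convention: code_f scores the induced function,
                # code_o the deduced output, and code_i the abduced input.)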
if task_type == 'induction':
self.assertEqual(task['extra_info']['metric'], 'code_f')
elif task_type == 'deduction':
self.assertEqual(task['extra_info']['metric'], 'code_o')
elif task_type == 'abduction':
self.assertEqual(task['extra_info']['metric'], 'code_i')
print("โœ… Task Generator AZR metadata test passed")
def test_data_converter_parquet_format(self):
"""๋ฐ์ดํ„ฐ ๋ณ€ํ™˜๊ธฐ์˜ parquet ํ˜•์‹ ํ…Œ์ŠคํŠธ"""
        # Create mock task data
mock_tasks = {
'induction': [
{
'task_id': 'induction_0',
'task_type': 'induction',
'prompt': 'Test prompt',
'uid': 'HumanEval_0_round_1_induction_0',
'ipo_group_id': 'HumanEval_0_program_0_ipo_0',
'source_program_id': 'program_0',
'ipo_index': 0,
'ipo_triple': {
'input': '[1, 2, 3]',
'output': '[2, 4, 6]',
'program': 'def test_func(lst):\n return [x * 2 for x in lst]'
},
'ground_truth': 'def test_func(lst):\n return [x * 2 for x in lst]',
'extra_info': {'metric': 'code_f'},
'basic_accuracy': 1.0,
'original_problem_id': 'HumanEval_0',
'round': 1
}
]
}
        # Create the complete pipeline mock
with patch('absolute_zero_reasoner.testtime.complete_pipeline.CompleteTestTimePipeline') as mock_pipeline:
pipeline = CompleteTestTimePipeline(self.config, self.logger)
            # Test the _save_azr_training_data method
output_dir = self.test_dir
problem_id = "HumanEval_0"
round_num = 1
saved_files = pipeline._save_azr_training_data(mock_tasks, problem_id, round_num, output_dir)
            # Verify the files were created
self.assertIn('induction', saved_files)
self.assertTrue(os.path.exists(saved_files['induction']))
            # Test reading the parquet file back
import pandas as pd
df = pd.read_parquet(saved_files['induction'])
            # Verify the data
self.assertEqual(len(df), 1)
self.assertIn('prompt', df.columns)
self.assertIn('uid', df.columns)
self.assertIn('ipo_group_id', df.columns)
self.assertIn('ground_truth', df.columns)
self.assertIn('basic_accuracy', df.columns)
            # Verify the prompt format (chat format)
prompt_data = df.iloc[0]['prompt']
self.assertIsInstance(prompt_data, list)
self.assertEqual(prompt_data[0]['role'], 'user')
self.assertIn('content', prompt_data[0])
print("โœ… Data converter parquet format test passed")
def test_complete_pipeline_basic_accuracy_update(self):
"""Complete Pipeline์˜ basic_accuracy ์—…๋ฐ์ดํŠธ ํ…Œ์ŠคํŠธ"""
# Mock components
with patch.multiple(
'absolute_zero_reasoner.testtime.complete_pipeline.CompleteTestTimePipeline',
_generate_task_response=Mock(return_value="test response"),
_extract_answer_by_task_type=Mock(return_value="test answer"),
_calculate_task_accuracy=Mock(return_value=0.8)
):
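            # With the LLM-dependent helpers mocked out above, the pipeline logic can
            # be exercised without loading a model.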
pipeline = CompleteTestTimePipeline(self.config, self.logger)
            # Mock task data
mock_tasks = {
'induction': [
{
'task_id': 'induction_0',
'prompt': 'test prompt',
'expected_solution': 'test solution',
'evaluation_data': {'test': 'data'},
                        'basic_accuracy': 0.0  # initial value
}
]
}
            # Run the task evaluation
evaluations = pipeline._evaluate_tasks_with_llm(mock_tasks)
            # Verify the basic_accuracy update
updated_task = mock_tasks['induction'][0]
            self.assertEqual(updated_task['basic_accuracy'], 0.8)  # value returned by the mock
            # Verify the evaluation results
self.assertIn('induction', evaluations)
eval_result = evaluations['induction'][0]
self.assertIn('basic_accuracy', eval_result)
self.assertEqual(eval_result['basic_accuracy'], 0.8)
print("โœ… Complete pipeline basic_accuracy update test passed")
def test_iterative_trainer_round_management(self):
"""Iterative Trainer์˜ ๋ผ์šด๋“œ ๊ด€๋ฆฌ ํ…Œ์ŠคํŠธ"""
# Mock benchmark config
benchmark_config = BenchmarkConfig(
name='test_benchmark',
problems_path='/test/path',
max_problems=None
)
problem_ids = ['TestProblem_1', 'TestProblem_2']
with patch.object(IterativeTrainer, '_update_pipeline_model'):
with patch.object(IterativeTrainer, '_train_azr_with_round_data') as mock_train:
with patch.object(IterativeTrainer, '_save_checkpoint'):
with patch('absolute_zero_reasoner.testtime.complete_pipeline.CompleteTestTimePipeline') as mock_pipeline_class:
                        # Mock pipeline result
mock_pipeline = Mock()
mock_pipeline.run_complete_pipeline.return_value = {
'success': True,
'azr_training_data': {'induction': '/test/path/induction.parquet'},
'steps': {
'azr_data_saving': {'total_tasks': 10}
}
}
mock_pipeline_class.return_value = mock_pipeline
                        # Mock AZR training result
mock_train.return_value = "/data/RLVR/checkpoints/ttrlvr_azr/round_1"
                        # Initialize the trainer
trainer = IterativeTrainer(self.config, self.logger)
trainer.checkpoint_dir = self.test_dir
                        # Test a single round (running all 30 rounds takes too long)
round_result = trainer._run_single_round(benchmark_config, problem_ids, 1)
                        # Verify the result
self.assertTrue(round_result['success'])
self.assertEqual(len(round_result['problems']), 2)
self.assertGreater(len(round_result['training_data_files']), 0)
                        # Verify the statistics
stats = round_result['stats']
self.assertEqual(stats['total_problems'], 2)
self.assertEqual(stats['successful_problems'], 2)
self.assertEqual(stats['failed_problems'], 0)
print("โœ… Iterative trainer round management test passed")
def test_data_combination_and_sorting(self):
"""๋ผ์šด๋“œ ๋ฐ์ดํ„ฐ ํ†ตํ•ฉ ๋ฐ ์ •๋ ฌ ํ…Œ์ŠคํŠธ"""
trainer = IterativeTrainer(self.config, self.logger)
# Mock training data files
training_data_files = [
{
'problem_id': 'TestProblem_1',
'files': {
'induction': os.path.join(self.test_dir, 'test_induction_1.parquet')
}
},
{
'problem_id': 'TestProblem_2',
'files': {
'induction': os.path.join(self.test_dir, 'test_induction_2.parquet')
}
}
]
        # Create mock parquet files
import pandas as pd
        # Data for the first problem
data1 = pd.DataFrame([
{
'uid': 'TestProblem_1_round_1_induction_0',
'ipo_group_id': 'TestProblem_1_program_1_ipo_2',
'basic_accuracy': 0.8
}
])
data1.to_parquet(training_data_files[0]['files']['induction'], index=False)
        # Data for the second problem
data2 = pd.DataFrame([
{
'uid': 'TestProblem_2_round_1_induction_0',
'ipo_group_id': 'TestProblem_2_program_0_ipo_1',
'basic_accuracy': 0.6
}
])
data2.to_parquet(training_data_files[1]['files']['induction'], index=False)
        # Test combining the data
combined_path = trainer._combine_round_data(training_data_files, 1)
self.assertIsNotNone(combined_path)
self.assertTrue(os.path.exists(combined_path))
        # Verify the combined file
combined_file = os.path.join(combined_path, 'induction.parquet')
self.assertTrue(os.path.exists(combined_file))
        # Verify the data ordering (rows must be sorted by ipo_group_id)
combined_df = pd.read_parquet(combined_file)
self.assertEqual(len(combined_df), 2)
        # Verify the ipo_group_id ordering
ipo_groups = combined_df['ipo_group_id'].tolist()
self.assertEqual(ipo_groups, sorted(ipo_groups))
print("โœ… Data combination and sorting test passed")
class TestPerformanceAndMemory(unittest.TestCase):
"""์„ฑ๋Šฅ ๋ฐ ๋ฉ”๋ชจ๋ฆฌ ํ…Œ์ŠคํŠธ"""
def setUp(self):
self.config = TestTimeConfig()
self.logger = TestTimeLogger()
def test_memory_cleanup_between_rounds(self):
"""๋ผ์šด๋“œ ๊ฐ„ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ ํ…Œ์ŠคํŠธ"""
with patch('absolute_zero_reasoner.testtime.complete_pipeline.CompleteTestTimePipeline') as mock_pipeline_class:
mock_pipeline = Mock()
mock_pipeline.model = Mock()
mock_pipeline.tokenizer = Mock()
mock_pipeline_class.return_value = mock_pipeline
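            # Assumes IterativeTrainer resolves CompleteTestTimePipeline through the
            # patched module path, so it receives the Mock above rather than a real pipeline.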
trainer = IterativeTrainer(self.config, self.logger)
            # Test the model update (includes memory cleanup)
trainer._update_pipeline_model("/new/model/path")
            # Confirm the model and tokenizer were set to None
self.assertIsNone(trainer.complete_pipeline.model)
self.assertIsNone(trainer.complete_pipeline.tokenizer)
print("โœ… Memory cleanup between rounds test passed")
def test_checkpoint_size_and_structure(self):
"""์ฒดํฌํฌ์ธํŠธ ํฌ๊ธฐ ๋ฐ ๊ตฌ์กฐ ํ…Œ์ŠคํŠธ"""
trainer = IterativeTrainer(self.config, self.logger)
test_dir = tempfile.mkdtemp()
trainer.checkpoint_dir = test_dir
try:
# Mock training results
training_results = {
'total_rounds': 30,
'rounds': {
1: {'success': True, 'stats': {'total_tasks': 100}},
2: {'success': True, 'stats': {'total_tasks': 95}},
3: {'success': False, 'error': 'Test error'}
}
}
            # Save the checkpoint
trainer._save_checkpoint(3, "/test/model/path", training_results)
            # Confirm the checkpoint files exist
checkpoint_path = os.path.join(test_dir, "checkpoint_round_3")
self.assertTrue(os.path.exists(checkpoint_path))
checkpoint_file = os.path.join(checkpoint_path, "checkpoint.json")
summary_file = os.path.join(checkpoint_path, "summary.txt")
self.assertTrue(os.path.exists(checkpoint_file))
self.assertTrue(os.path.exists(summary_file))
            # Verify the checkpoint contents
with open(checkpoint_file, 'r') as f:
checkpoint_data = json.load(f)
self.assertEqual(checkpoint_data['round_num'], 3)
self.assertEqual(checkpoint_data['model_path'], "/test/model/path")
self.assertIn('training_results', checkpoint_data)
            # Check the file size (it should not be too large)
checkpoint_size = os.path.getsize(checkpoint_file)
            self.assertLess(checkpoint_size, 1024 * 1024)  # under 1 MB
finally:
shutil.rmtree(test_dir)
print("โœ… Checkpoint size and structure test passed")
def run_integration_test():
"""ํ†ตํ•ฉ ํ…Œ์ŠคํŠธ ์‹คํ–‰"""
print("๐Ÿงช TTRLVR + AZR ํ†ตํ•ฉ ์‹œ์Šคํ…œ ํ…Œ์ŠคํŠธ ์‹œ์ž‘")
print("=" * 60)
    # Create the test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()
    # Add the test classes
suite.addTests(loader.loadTestsFromTestCase(TestTTRLVRAZRIntegration))
suite.addTests(loader.loadTestsFromTestCase(TestPerformanceAndMemory))
    # Run the tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
    # Summarize the results
print("\n" + "=" * 60)
print("๐Ÿ“Š ํ…Œ์ŠคํŠธ ๊ฒฐ๊ณผ ์š”์•ฝ:")
print(f" - ์‹คํ–‰๋œ ํ…Œ์ŠคํŠธ: {result.testsRun}")
print(f" - ์„ฑ๊ณต: {result.testsRun - len(result.failures) - len(result.errors)}")
print(f" - ์‹คํŒจ: {len(result.failures)}")
print(f" - ์˜ค๋ฅ˜: {len(result.errors)}")
if result.failures:
print("\nโŒ ์‹คํŒจํ•œ ํ…Œ์ŠคํŠธ:")
for test, traceback in result.failures:
print(f" - {test}: {traceback}")
if result.errors:
print("\n๐Ÿ’ฅ ์˜ค๋ฅ˜ ๋ฐœ์ƒ ํ…Œ์ŠคํŠธ:")
for test, traceback in result.errors:
print(f" - {test}: {traceback}")
success = len(result.failures) == 0 and len(result.errors) == 0
if success:
print("\n๐ŸŽ‰ ๋ชจ๋“  ํ…Œ์ŠคํŠธ ์„ฑ๊ณต!")
else:
print("\nโš ๏ธ ์ผ๋ถ€ ํ…Œ์ŠคํŠธ ์‹คํŒจ ๋˜๋Š” ์˜ค๋ฅ˜ ๋ฐœ์ƒ")
return success
if __name__ == '__main__':
success = run_integration_test()
sys.exit(0 if success else 1)