| """Tests for dataset creation and loading module.""" |
|
|
| import json |
| import tempfile |
| from pathlib import Path |
|
|
| import pytest |
|
|
|
|
class TestDatasetCreation:
    """Exercise the private example generators in ``biorlhf.data.dataset``."""

    def test_generate_factual_examples_import(self):
        """The factual generator is importable and returns a non-empty list."""
        from biorlhf.data.dataset import _generate_factual_examples

        generated = _generate_factual_examples()
        assert isinstance(generated, list)
        assert len(generated) > 0

    def test_factual_examples_structure(self):
        """Every factual example carries the three expected keys."""
        from biorlhf.data.dataset import _generate_factual_examples

        # Checked in this order so the first missing key is the one reported.
        required_keys = ("instruction", "output", "input")
        for example in _generate_factual_examples():
            for key in required_keys:
                assert key in example

    def test_generate_comparison_examples(self):
        """The comparison generator returns a non-empty list with a sensitivity prompt."""
        from biorlhf.data.dataset import _generate_comparison_examples

        generated = _generate_comparison_examples()
        assert isinstance(generated, list)
        assert len(generated) > 0

        # At least one instruction must mention "most sensitive" (case-insensitive).
        prompts = [item["instruction"] for item in generated]
        mentions_sensitivity = ("most sensitive" in prompt.lower() for prompt in prompts)
        assert any(mentions_sensitivity)

    def test_generate_interaction_examples(self):
        """The interaction generator returns a list of exactly four examples."""
        from biorlhf.data.dataset import _generate_interaction_examples

        generated = _generate_interaction_examples()
        assert isinstance(generated, list)

        expected_count = 4
        assert len(generated) == expected_count

    def test_generate_design_critique_examples(self):
        """The design-critique generator returns a non-empty list."""
        from biorlhf.data.dataset import _generate_design_critique_examples

        generated = _generate_design_critique_examples()
        assert isinstance(generated, list)
        assert len(generated) > 0

    def test_generate_mechanistic_examples(self):
        """The mechanistic-reasoning generator returns a non-empty list."""
        from biorlhf.data.dataset import _generate_mechanistic_examples

        generated = _generate_mechanistic_examples()
        assert isinstance(generated, list)
        assert len(generated) > 0

    def test_generate_calibration_examples(self):
        """Calibration examples are non-empty and each output uses an uncertainty marker."""
        from biorlhf.data.dataset import _generate_calibration_examples

        generated = _generate_calibration_examples()
        assert isinstance(generated, list)
        assert len(generated) > 0

        # Substrings that count as expressing uncertainty; matched against the
        # lower-cased output of every example.
        markers = ("cannot", "insufficient", "confidence", "needed", "missing")
        for example in generated:
            lowered = example["output"].lower()
            assert any(marker in lowered for marker in markers), (
                f"Calibration example should express uncertainty: {example['output'][:100]}"
            )
|
|
|
|
class TestCreateSFTDataset:
    """Checks for ``create_sft_dataset``: file output, record format, and toggles."""

    def test_creates_dataset_file(self):
        """Calling create_sft_dataset writes the target file and returns a non-empty list."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir, "test_dataset.json")
            records = create_sft_dataset(output_path=target)

            # The existence check must run while the temporary directory is alive.
            assert target.exists()
            assert isinstance(records, list)
            assert len(records) > 0

    def test_dataset_format(self):
        """Each returned record has a "text" field containing both section headers."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir, "test_dataset.json")
            records = create_sft_dataset(output_path=target)

            for record in records:
                assert "text" in record
                body = record["text"]

                assert "### Instruction:" in body
                assert "### Response:" in body

    def test_dataset_json_valid(self):
        """The written file parses as JSON and its top level is a list."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir, "test_dataset.json")
            create_sft_dataset(output_path=target)

            with target.open() as handle:
                parsed = json.load(handle)

            assert isinstance(parsed, list)

    def test_exclude_calibration(self):
        """Turning off include_calibration yields strictly fewer records."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as workdir:
            enabled_path = Path(workdir, "with_cal.json")
            disabled_path = Path(workdir, "without_cal.json")

            enabled = create_sft_dataset(output_path=enabled_path, include_calibration=True)
            disabled = create_sft_dataset(output_path=disabled_path, include_calibration=False)

            assert len(enabled) > len(disabled)

    def test_exclude_chain_of_thought(self):
        """Turning off include_chain_of_thought yields strictly fewer records."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as workdir:
            enabled_path = Path(workdir, "with_cot.json")
            disabled_path = Path(workdir, "without_cot.json")

            enabled = create_sft_dataset(output_path=enabled_path, include_chain_of_thought=True)
            disabled = create_sft_dataset(output_path=disabled_path, include_chain_of_thought=False)

            assert len(enabled) > len(disabled)
|
|
|
|
class TestLoadDataset:
    """Checks for ``load_dataset`` against a file produced by ``create_sft_dataset``."""

    def test_load_dataset_basic(self):
        """Loading with test_size=0 returns a sized, non-empty object."""
        from biorlhf.data.dataset import create_sft_dataset, load_dataset

        with tempfile.TemporaryDirectory() as workdir:
            source_file = Path(workdir, "test_dataset.json")
            create_sft_dataset(output_path=source_file)

            loaded = load_dataset(source_file, test_size=0)

            assert hasattr(loaded, "__len__")
            assert len(loaded) > 0

    def test_load_dataset_with_split(self):
        """Loading with test_size=0.2 exposes "train" and "test", train being larger."""
        from biorlhf.data.dataset import create_sft_dataset, load_dataset

        with tempfile.TemporaryDirectory() as workdir:
            source_file = Path(workdir, "test_dataset.json")
            create_sft_dataset(output_path=source_file)

            partitions = load_dataset(source_file, test_size=0.2)

            for split_name in ("train", "test"):
                assert split_name in partitions
            assert len(partitions["train"]) > len(partitions["test"])

    def test_load_specific_split(self):
        """Requesting split="train" returns a single sized object, not a dict."""
        from biorlhf.data.dataset import create_sft_dataset, load_dataset

        with tempfile.TemporaryDirectory() as workdir:
            source_file = Path(workdir, "test_dataset.json")
            create_sft_dataset(output_path=source_file)

            train_only = load_dataset(source_file, split="train", test_size=0.2)

            assert not isinstance(train_only, dict)
            assert hasattr(train_only, "__len__")
|
|