# tests/test_train.py
from pathlib import Path
import sys
import tempfile
from unittest.mock import patch
from datasets import Dataset
import numpy as np
import pandas as pd
import pytest
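# Make the project package importable when the tests are run directly from the
# repository checkout; an editable install (`pip install -e .`) would make this
# path tweak unnecessary.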
sys.path.insert(0, str(Path(__file__).parent.parent))
from syntetic_issue_report_data_generation.config import MODEL_CONFIGS, DATASET_CONFIGs
from syntetic_issue_report_data_generation.modeling.train import (
init_parser,
load_and_prepare_data,
train_model_setfit,
train_model_transformers,
)
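
# These tests patch SOFT_CLEANED_DATA_DIR to point at a temporary directory, so no
# real data files are read or written. Typical invocations (assumed; adjust to the
# project's pytest configuration):
#
#   pytest tests/test_train.py -m "not slow"   # fast data/config tests only
#   pytest tests/test_train.py                 # also runs the slow training tests
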
@pytest.fixture
def temp_data_dir():
"""Create a temporary directory for test data files."""
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)
@pytest.fixture
def sample_train_data():
"""Create sample training data with balanced classes."""
return pd.DataFrame(
{
"title": [
"Bug in login",
"Feature request",
"Performance issue",
"UI problem",
"Bug in logout",
"Add search",
"Memory leak",
"New API endpoint",
"Crash on startup",
"Enhancement needed",
],
"body": [
"Cannot login to system",
"Add dark mode feature",
"Slow loading times",
"Button misaligned",
"Cannot logout properly",
"Need search functionality",
"High memory usage",
"REST API needed",
"Application crashes",
"Improve user experience",
],
"label": [
"bug",
"enhancement",
"bug",
"bug",
"bug",
"enhancement",
"bug",
"enhancement",
"bug",
"enhancement",
],
}
)
@pytest.fixture
def sample_imbalanced_data():
"""Create sample data with imbalanced classes (for stratified sampling test)."""
return pd.DataFrame(
{
"title": [
"Bug 1",
"Bug 2",
"Bug 3",
"Bug 4",
"Bug 5",
"Bug 6",
"Bug 7",
"Bug 8",
"Enhancement 1",
"Enhancement 2",
],
"body": [
"Bug body 1",
"Bug body 2",
"Bug body 3",
"Bug body 4",
"Bug body 5",
"Bug body 6",
"Bug body 7",
"Bug body 8",
"Enhancement body 1",
"Enhancement body 2",
],
"label": [
"bug",
"bug",
"bug",
"bug",
"bug",
"bug",
"bug",
"bug",
"enhancement",
"enhancement",
],
}
)
@pytest.fixture
def train_config_with_title(temp_data_dir, sample_train_data):
"""Create train config with title and body columns."""
train_path = temp_data_dir / "train_with_title.csv"
sample_train_data.to_csv(train_path, index=False)
return {
"data_path": "train_with_title.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
@pytest.fixture
def imbalanced_train_config(temp_data_dir, sample_imbalanced_data):
"""Create train config with imbalanced data."""
train_path = temp_data_dir / "train_imbalanced.csv"
sample_imbalanced_data.to_csv(train_path, index=False)
return {
"data_path": "train_imbalanced.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
@pytest.fixture
def minimal_train_data():
"""Create minimal training data for quick training tests."""
return pd.DataFrame(
{
"title": [
"Bug 1",
"Bug 2",
"Enhancement 1",
"Enhancement 2",
"Bug 3",
"Enhancement 3",
],
"body": [
"Bug body 1",
"Bug body 2",
"Enh body 1",
"Enh body 2",
"Bug body 3",
"Enh body 3",
],
"label": ["bug", "bug", "enhancement", "enhancement", "bug", "enhancement"],
}
)
@pytest.fixture
def minimal_train_config(temp_data_dir, minimal_train_data):
"""Create train config with minimal data for fast training."""
train_path = temp_data_dir / "minimal_train.csv"
minimal_train_data.to_csv(train_path, index=False)
return {
"data_path": "minimal_train.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
@pytest.fixture
def minimal_model_config_setfit():
"""Create minimal SetFit model configuration for testing."""
return {
"model_checkpoint": "sentence-transformers/paraphrase-MiniLM-L3-v2", # Small, fast model
"params": {"num_epochs": 1, "batch_size": 4, "num_iterations": 5, "max_length": 64},
}
@pytest.fixture
def minimal_model_config_transformers():
"""Create minimal Transformers model configuration for testing."""
return {
"model_checkpoint": "prajjwal1/bert-tiny", # Very small BERT model
"params": {
"num_train_epochs": 1,
"per_device_train_batch_size": 2,
"per_device_eval_batch_size": 2,
"learning_rate": 5e-5,
"warmup_steps": 0,
"weight_decay": 0.01,
"logging_steps": 1,
},
}
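
# The checkpoints above ("sentence-transformers/paraphrase-MiniLM-L3-v2" and
# "prajjwal1/bert-tiny") are small public Hugging Face Hub models chosen so the slow
# training tests finish quickly; the first run needs network access or a pre-populated
# local HF cache to download them.
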
class TestDataLoadingAndPreparation:
"""Test class for data loading and preparation functionality."""
def test_load_data_with_valid_config(self, train_config_with_title, temp_data_dir):
"""Verify that data loads correctly with valid train dataset configuration."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
train_ds, test_ds, eval_strategy = load_and_prepare_data(
train_config_with_title, test_config=None, test_size=0.2
)
# Check that datasets were created
assert isinstance(train_ds, Dataset)
assert isinstance(test_ds, Dataset)
# Check that datasets have correct columns
assert set(train_ds.column_names) == {"text", "label"}
assert set(test_ds.column_names) == {"text", "label"}
# Check that datasets are not empty
assert len(train_ds) > 0
assert len(test_ds) > 0
# Check total samples
assert len(train_ds) + len(test_ds) == 10
# Check label encoder is attached
assert hasattr(train_ds, "label_encoder")
assert hasattr(test_ds, "label_encoder")
# Check that labels are integers
assert all(isinstance(label, int) for label in train_ds["label"])
assert all(isinstance(label, int) for label in test_ds["label"])
def test_load_data_creates_holdout_split(self, train_config_with_title, temp_data_dir):
"""Verify holdout split is created when no test dataset is provided."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
train_ds, test_ds, eval_strategy = load_and_prepare_data(
train_config_with_title, test_config=None, test_size=0.2
)
# Check eval strategy is holdout
assert eval_strategy == "holdout"
# Check that both datasets exist
assert len(train_ds) > 0
assert len(test_ds) > 0
# Check that split is approximately correct (80/20 split of 10 samples)
total_samples = len(train_ds) + len(test_ds)
assert total_samples == 10
assert len(test_ds) == 2 # 20% of 10 = 2
assert len(train_ds) == 8 # 80% of 10 = 8
# Verify no data leakage (no overlap between train and test)
train_texts = set(train_ds["text"])
test_texts = set(test_ds["text"])
assert len(train_texts.intersection(test_texts)) == 0
def test_label_encoding_consistency(self, train_config_with_title, temp_data_dir):
"""Verify labels are encoded consistently across train/test sets."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
train_ds, test_ds, eval_strategy = load_and_prepare_data(
train_config_with_title, test_config=None, test_size=0.2
)
# Check that label encoders are the same object
assert train_ds.label_encoder is test_ds.label_encoder
# Check that labels are integers
assert all(isinstance(label, int) for label in train_ds["label"])
assert all(isinstance(label, int) for label in test_ds["label"])
# Check that label classes are consistent
train_label_classes = train_ds.label_encoder.classes_
test_label_classes = test_ds.label_encoder.classes_
assert list(train_label_classes) == list(test_label_classes)
# Check that we have the expected classes (bug and enhancement)
expected_classes = sorted(["bug", "enhancement"])
actual_classes = sorted(train_ds.label_encoder.classes_)
assert actual_classes == expected_classes
# Check that encoded labels are in valid range [0, num_classes)
num_classes = len(train_ds.label_encoder.classes_)
assert num_classes == 2
assert all(0 <= label < num_classes for label in train_ds["label"])
assert all(0 <= label < num_classes for label in test_ds["label"])
# Check that both classes appear in train set (due to stratification)
train_unique_labels = set(train_ds["label"])
assert len(train_unique_labels) == 2
def test_text_column_creation_with_title_and_body(
self, train_config_with_title, temp_data_dir
):
"""Verify text column combines title and body correctly."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
train_ds, test_ds, eval_strategy = load_and_prepare_data(
train_config_with_title, test_config=None, test_size=0.2
)
# Check that text column exists and is the only text column
assert "text" in train_ds.column_names
assert "text" in test_ds.column_names
assert "title" not in train_ds.column_names
assert "body" not in train_ds.column_names
# Check that all text entries are non-empty strings
assert all(isinstance(text, str) and len(text) > 0 for text in train_ds["text"])
assert all(isinstance(text, str) and len(text) > 0 for text in test_ds["text"])
# Check that text contains content (not just whitespace)
assert all(len(text.strip()) > 0 for text in train_ds["text"])
assert all(len(text.strip()) > 0 for text in test_ds["text"])
            # Rough proxy for concatenation of title and body: the combined text
            # should have a reasonable minimum length
for text in train_ds["text"]:
# Text should have reasonable length (at least 10 chars)
assert len(text) >= 10
def test_max_train_samples_stratified_sampling(self, imbalanced_train_config, temp_data_dir):
"""Verify stratified sampling works correctly when max_train_samples is specified."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Original data has 10 samples: 8 bugs, 2 enhancements
# Request only 4 samples
train_ds, test_ds, eval_strategy = load_and_prepare_data(
imbalanced_train_config, test_config=None, test_size=0.2, max_train_samples=4
)
# Check that train dataset size is reduced to 4
assert len(train_ds) == 4
# Check that we still have both classes (stratified sampling)
unique_labels = set(train_ds["label"])
assert len(unique_labels) == 2, "Stratified sampling should preserve both classes"
# Check that class distribution is approximately maintained
# Original: 80% bug, 20% enhancement
# With 4 samples: should have ~3 bugs, ~1 enhancement
label_counts = {}
for label in train_ds["label"]:
label_name = train_ds.label_encoder.inverse_transform([label])[0]
label_counts[label_name] = label_counts.get(label_name, 0) + 1
# At least one sample from minority class
assert label_counts.get("enhancement", 0) >= 1
# Majority class should have more samples
assert label_counts.get("bug", 0) >= label_counts.get("enhancement", 0)
# Test dataset should remain at 20% of original (2 samples)
assert len(test_ds) == 2
# Total should be less than original
assert len(train_ds) + len(test_ds) < 10
class TestConfiguration:
"""Test class for configuration and argument parsing functionality."""
def test_parser_accepts_valid_arguments(self):
"""Verify parser accepts all valid combinations of arguments."""
parser = init_parser()
# Get first valid dataset and model from configs
valid_dataset = list(DATASET_CONFIGs.keys())[0]
valid_model = list(MODEL_CONFIGS.keys())[0]
# Test 1: Minimal required arguments
args = parser.parse_args(["--train-dataset", valid_dataset, "--model-name", valid_model])
assert args.train_dataset == valid_dataset
assert args.model_name == valid_model
assert args.test_dataset is None
assert args.test_size == 0.2 # default value
assert args.max_train_samples is None # default value
assert args.use_setfit is False # default value
assert args.run_name is None # default value
# Test 2: All arguments provided
if len(DATASET_CONFIGs.keys()) > 1:
valid_test_dataset = list(DATASET_CONFIGs.keys())[1]
else:
valid_test_dataset = valid_dataset
args = parser.parse_args(
[
"--train-dataset",
valid_dataset,
"--test-dataset",
valid_test_dataset,
"--model-name",
valid_model,
"--use-setfit",
"--test-size",
"0.3",
"--max-train-samples",
"100",
"--run-name",
"test_run",
]
)
assert args.train_dataset == valid_dataset
assert args.test_dataset == valid_test_dataset
assert args.model_name == valid_model
assert args.use_setfit is True
assert args.test_size == 0.3
assert args.max_train_samples == 100
assert args.run_name == "test_run"
# Test 3: Only use-setfit flag
args = parser.parse_args(
["--train-dataset", valid_dataset, "--model-name", valid_model, "--use-setfit"]
)
assert args.use_setfit is True
# Test 4: Custom test-size
args = parser.parse_args(
["--train-dataset", valid_dataset, "--model-name", valid_model, "--test-size", "0.15"]
)
assert args.test_size == 0.15
# Test 5: Custom max-train-samples
args = parser.parse_args(
[
"--train-dataset",
valid_dataset,
"--model-name",
valid_model,
"--max-train-samples",
"500",
]
)
assert args.max_train_samples == 500
def test_parser_rejects_invalid_dataset_names(self):
"""Verify parser rejects dataset names not in DATASET_CONFIGs."""
parser = init_parser()
# Get valid model
valid_model = list(MODEL_CONFIGS.keys())[0]
# Test 1: Invalid train dataset
with pytest.raises(SystemExit):
parser.parse_args(
["--train-dataset", "invalid_dataset_name", "--model-name", valid_model]
)
# Test 2: Invalid test dataset
valid_dataset = list(DATASET_CONFIGs.keys())[0]
with pytest.raises(SystemExit):
parser.parse_args(
[
"--train-dataset",
valid_dataset,
"--test-dataset",
"invalid_test_dataset",
"--model-name",
valid_model,
]
)
# Test 3: Invalid model name
with pytest.raises(SystemExit):
parser.parse_args(
["--train-dataset", valid_dataset, "--model-name", "invalid_model_name"]
)
# Test 4: Missing required argument (train-dataset)
with pytest.raises(SystemExit):
parser.parse_args(["--model-name", valid_model])
# Test 5: Missing required argument (model-name)
with pytest.raises(SystemExit):
parser.parse_args(["--train-dataset", valid_dataset])
class TestTrainingPipeline:
    """Test class for the end-to-end training functions (SetFit and Transformers)."""
@pytest.mark.slow
def test_setfit_training_completes(
self, minimal_train_config, minimal_model_config_setfit, temp_data_dir
):
"""Verify SetFit training runs without errors (using minimal data)."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Load data
train_ds, test_ds, eval_strategy = load_and_prepare_data(
minimal_train_config,
test_config=None,
test_size=0.33, # 2 samples for test, 4 for train
)
# Mock MLflow to avoid logging during tests
with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
# Train model
result = train_model_setfit(minimal_model_config_setfit, train_ds, test_ds)
# Check that result is returned
assert result is not None
# Check that result has expected structure
model, metrics, y_true, y_pred, model_type = result
# Check model type
assert model_type == "setfit"
# Check that model is returned
assert model is not None
# Check that metrics are computed
assert isinstance(metrics, dict)
assert "accuracy" in metrics
assert "f1_macro" in metrics
assert "f1_weighted" in metrics
# Check that metrics are in valid range [0, 1]
assert 0 <= metrics["accuracy"] <= 1
assert 0 <= metrics["f1_macro"] <= 1
assert 0 <= metrics["f1_weighted"] <= 1
# Check that predictions are returned
assert y_true is not None
assert y_pred is not None
# Check that predictions have correct length
assert len(y_true) == len(test_ds)
assert len(y_pred) == len(test_ds)
# Check that predictions are in valid label space
num_classes = len(train_ds.label_encoder.classes_)
assert all(0 <= pred < num_classes for pred in y_pred)
assert all(0 <= true < num_classes for true in y_true)
@pytest.mark.slow
def test_transformers_training_completes(
self, minimal_train_config, minimal_model_config_transformers, temp_data_dir
):
"""Verify Transformers training runs without errors (using minimal data)."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Load data
train_ds, test_ds, eval_strategy = load_and_prepare_data(
minimal_train_config,
test_config=None,
test_size=0.33, # 2 samples for test, 4 for train
)
# Mock MLflow to avoid logging during tests
with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
# Train model
result = train_model_transformers(
minimal_model_config_transformers, train_ds, test_ds
)
# Check that result is returned
assert result is not None
# Check that result has expected structure
model_tuple, metrics, y_true, y_pred, model_type = result
# Check model type
assert model_type == "transformers"
# Check that model and tokenizer are returned
assert model_tuple is not None
assert isinstance(model_tuple, tuple)
assert len(model_tuple) == 2
model, tokenizer = model_tuple
assert model is not None
assert tokenizer is not None
# Check that metrics are computed
assert isinstance(metrics, dict)
# Transformers returns metrics with eval_ prefix
assert any("accuracy" in key for key in metrics.keys())
# Extract accuracy value (could be 'accuracy' or 'eval_accuracy')
accuracy_key = [k for k in metrics.keys() if "accuracy" in k][0]
accuracy = metrics[accuracy_key]
assert 0 <= accuracy <= 1
# Check that predictions are returned
assert y_true is not None
assert y_pred is not None
# Check that predictions have correct length
assert len(y_true) == len(test_ds)
assert len(y_pred) == len(test_ds)
# Check that predictions are in valid label space
num_classes = len(train_ds.label_encoder.classes_)
assert all(0 <= pred < num_classes for pred in y_pred)
assert all(0 <= true < num_classes for true in y_true)
# Check that predictions are numpy arrays or lists of integers
assert all(isinstance(pred, (int, np.integer)) for pred in y_pred)
assert all(isinstance(true, (int, np.integer)) for true in y_true)
class TestErrorHandling:
"""Test class for error handling functionality."""
def test_missing_train_file_raises_error(self, temp_data_dir):
"""Verify appropriate error when train file doesn't exist."""
# Create config pointing to non-existent file
missing_file_config = {
"data_path": "non_existent_file.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Should call sys.exit(1) when file doesn't exist
with pytest.raises(SystemExit) as excinfo:
load_and_prepare_data(missing_file_config, test_config=None, test_size=0.2)
# Check that exit code is 1
assert excinfo.value.code == 1
def test_invalid_label_column(self, temp_data_dir):
"""Verify error handling when specified label column doesn't exist."""
# Create data with specific columns
sample_data = pd.DataFrame(
{
"title": ["Bug 1", "Enhancement 1"],
"body": ["Bug body", "Enhancement body"],
"type": ["bug", "enhancement"], # Different column name
}
)
# Save to file
train_path = temp_data_dir / "invalid_label_col.csv"
sample_data.to_csv(train_path, index=False)
# Create config with wrong label column name
invalid_label_config = {
"data_path": "invalid_label_col.csv",
"label_col": "label", # This column doesn't exist
"title_col": "title",
"body_col": "body",
"sep": ",",
}
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Should call sys.exit(1) when label column doesn't exist
with pytest.raises(SystemExit) as excinfo:
load_and_prepare_data(invalid_label_config, test_config=None, test_size=0.2)
# Check that exit code is 1
assert excinfo.value.code == 1
def test_invalid_body_column(self, temp_data_dir):
"""Verify error handling when specified body column doesn't exist."""
# Create data with specific columns
sample_data = pd.DataFrame(
{
"title": ["Bug 1", "Enhancement 1"],
"description": ["Bug body", "Enhancement body"], # Different column name
"label": ["bug", "enhancement"],
}
)
# Save to file
train_path = temp_data_dir / "invalid_body_col.csv"
sample_data.to_csv(train_path, index=False)
# Create config with wrong body column name
invalid_body_config = {
"data_path": "invalid_body_col.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body", # This column doesn't exist
"sep": ",",
}
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Should call sys.exit(1) when body column doesn't exist
with pytest.raises(SystemExit) as excinfo:
load_and_prepare_data(invalid_body_config, test_config=None, test_size=0.2)
# Check that exit code is 1
assert excinfo.value.code == 1
class TestEdgeCases:
"""Test class for edge case scenarios."""
def test_very_small_dataset(self, temp_data_dir):
"""Verify training with very small datasets (< 10 samples)."""
# Create very small dataset (6 samples total, 3 per class)
very_small_data = pd.DataFrame(
{
"title": ["Bug 1", "Bug 2", "Bug 3", "Enh 1", "Enh 2", "Enh 3"],
"body": [
"Small bug 1",
"Small bug 2",
"Small bug 3",
"Small enh 1",
"Small enh 2",
"Small enh 3",
],
"label": ["bug", "bug", "bug", "enhancement", "enhancement", "enhancement"],
}
)
# Save to file
train_path = temp_data_dir / "very_small.csv"
very_small_data.to_csv(train_path, index=False)
small_config = {
"data_path": "very_small.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Load data with small test split to ensure at least 1 sample per class in train
train_ds, test_ds, eval_strategy = load_and_prepare_data(
small_config,
test_config=None,
test_size=0.33, # 2 samples for test (1 per class), 4 for train
)
# Check that datasets were created despite small size
assert isinstance(train_ds, Dataset)
assert isinstance(test_ds, Dataset)
# Check that both datasets have samples
assert len(train_ds) > 0
assert len(test_ds) > 0
# Check total is preserved
assert len(train_ds) + len(test_ds) == 6
            # Check that the train set keeps at least one class (stratification should normally keep both)
train_unique_labels = set(train_ds["label"])
assert len(train_unique_labels) >= 1 # At least one class
# Check that we have valid label encoding
num_classes = len(train_ds.label_encoder.classes_)
assert num_classes == 2
assert all(0 <= label < num_classes for label in train_ds["label"])
assert all(0 <= label < num_classes for label in test_ds["label"])
# Check that text was properly created
assert all(isinstance(text, str) and len(text) > 0 for text in train_ds["text"])
assert all(isinstance(text, str) and len(text) > 0 for text in test_ds["text"])
class TestOutputValidation:
"""Test class for output validation functionality."""
@pytest.mark.slow
def test_predictions_match_label_space(
self,
minimal_train_config,
minimal_model_config_setfit,
minimal_model_config_transformers,
temp_data_dir,
):
"""Verify predictions are within valid label space."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Load data
train_ds, test_ds, eval_strategy = load_and_prepare_data(
minimal_train_config,
test_config=None,
                test_size=0.33,  # 2 samples for test, 4 for train
)
# Get the valid label space
num_classes = len(train_ds.label_encoder.classes_)
valid_label_space = set(range(num_classes))
# Mock MLflow to avoid logging during tests
with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
# Test with SetFit
model, metrics, y_true, y_pred, model_type = train_model_setfit(
minimal_model_config_setfit, train_ds, test_ds
)
# Check that all predictions are in valid label space
assert all(
pred in valid_label_space for pred in y_pred
), f"SetFit predictions contain invalid labels. Valid: {valid_label_space}, Got: {set(y_pred)}"
# Check that all true labels are in valid label space
assert all(
true in valid_label_space for true in y_true
), f"True labels contain invalid values. Valid: {valid_label_space}, Got: {set(y_true)}"
# Check that predictions are within [0, num_classes)
assert all(
0 <= pred < num_classes for pred in y_pred
), f"SetFit predictions out of range [0, {num_classes})"
# Check that y_true matches the original test labels
assert list(y_true) == list(
test_ds["label"]
), "True labels don't match original test dataset labels"
# Test with Transformers
(model_t, tokenizer), metrics_t, y_true_t, y_pred_t, model_type_t = (
train_model_transformers(minimal_model_config_transformers, train_ds, test_ds)
)
# Check that all predictions are in valid label space
assert all(
pred in valid_label_space for pred in y_pred_t
), f"Transformers predictions contain invalid labels. Valid: {valid_label_space}, Got: {set(y_pred_t)}"
# Check that all true labels are in valid label space
assert all(
true in valid_label_space for true in y_true_t
), f"True labels contain invalid values. Valid: {valid_label_space}, Got: {set(y_true_t)}"
# Check that predictions are within [0, num_classes)
assert all(
0 <= pred < num_classes for pred in y_pred_t
), f"Transformers predictions out of range [0, {num_classes})"
# Check that y_true matches the original test labels
assert list(y_true_t) == list(
test_ds["label"]
), "True labels don't match original test dataset labels"
# Additional check: verify predictions are integers
assert all(
isinstance(pred, (int, np.integer)) for pred in y_pred
), "SetFit predictions must be integers"
assert all(
isinstance(pred, (int, np.integer)) for pred in y_pred_t
), "Transformers predictions must be integers"
# Check that at least some predictions were made (not all same)
# This is a sanity check - with random initialization, we should get some variation
# (Though with very small data, it's possible all predictions are the same)
unique_preds = len(set(y_pred))
unique_preds_t = len(set(y_pred_t))
assert unique_preds >= 1, "SetFit made no predictions"
assert unique_preds_t >= 1, "Transformers made no predictions"