| """Tests for src/train.py - Training pipeline helper functions.""" |
|
|
| import numpy as np |
| import pandas as pd |
|
|
| from src.train import ( |
| apply_cardinality_reduction, |
| compute_currency_rates, |
| drop_other_rows, |
| extract_valid_categories, |
| filter_salaries, |
| ) |
|
|
|
|
def _make_salary_df(countries=None, salaries=None, n=100) -> pd.DataFrame:
    """Build a small, deterministic DataFrame shaped like the survey data.

    Args:
        countries: Values for the ``Country`` column. When given, its length
            sets the row count (it wins over both ``salaries`` and ``n``).
            Defaults to ``"United States of America"`` for every row.
        salaries: Values written to both ``CompTotal`` and
            ``ConvertedCompYearly``. When given and ``countries`` is not, its
            length sets the row count. Defaults to integers drawn from
            ``[30000, 200000)`` with seed 42, cast to float.
        n: Row count, used only when neither ``countries`` nor ``salaries``
            is supplied.

    Returns:
        A DataFrame whose categorical columns each hold one constant value and
        whose ``YearsCode``/``WorkExp`` columns are seeded random floats.
    """
    if countries is not None:
        n = len(countries)
    elif salaries is not None:
        n = len(salaries)

    if countries is None:
        countries = ["United States of America"] * n
    if salaries is None:
        salary_rng = np.random.default_rng(42)
        salaries = salary_rng.integers(30000, 200000, size=n).astype(float)

    # A fresh generator with the same seed keeps the experience columns
    # identical whether or not salaries were supplied by the caller.
    experience_rng = np.random.default_rng(42)
    years_code = experience_rng.integers(0, 30, size=n).astype(float)
    work_exp = experience_rng.integers(0, 20, size=n).astype(float)

    def constant(value):
        return [value] * n

    return pd.DataFrame(
        {
            "Country": countries,
            "YearsCode": years_code,
            "WorkExp": work_exp,
            "EdLevel": constant("Bachelor's degree (B.A., B.S., B.Eng., etc.)"),
            "DevType": constant("Developer, full-stack"),
            "Industry": constant("Software Development"),
            "Age": constant("25-34 years old"),
            "ICorPM": constant("Individual contributor"),
            "OrgSize": constant("20 to 99 employees"),
            "Employment": constant("Employed"),
            "Currency": constant("USD United States Dollar"),
            "CompTotal": salaries,
            "ConvertedCompYearly": salaries,
        }
    )
|
|
|
|
class TestFilterSalaries:
    """Tests for filter_salaries()."""

    @staticmethod
    def _config(min_salary=1000, lower_percentile=0, upper_percentile=100):
        """Build the ``{"data": {...}}`` config section filter_salaries() reads."""
        return {
            "data": {
                "min_salary": min_salary,
                "lower_percentile": lower_percentile,
                "upper_percentile": upper_percentile,
            }
        }

    def test_removes_below_min_salary(self):
        """Rows with salary below min_salary are removed."""
        salaries = [500.0] * 5 + [2000.0] * 20 + [50000.0] * 20
        df = _make_salary_df(salaries=salaries)
        # 0/100 percentiles: presumably no outlier trimming, so only the
        # min_salary threshold should remove rows here.
        result = filter_salaries(df, self._config())
        assert (result["ConvertedCompYearly"] > 1000).all()
        assert len(result) < len(df)

    def test_removes_outliers_by_percentile(self):
        """Per-country percentile outlier removal drops the extreme salary."""
        salaries = [10000.0] * 50 + [500000.0] + [10000.0] * 49
        df = _make_salary_df(salaries=salaries)
        result = filter_salaries(
            df, self._config(lower_percentile=2, upper_percentile=98)
        )
        assert len(result) < len(df)
        # Dropping *some* row is not enough: the outlier itself must be gone.
        assert 500000.0 not in result["ConvertedCompYearly"].to_numpy()

    def test_drops_missing_target(self):
        """Rows with NaN target are dropped."""
        df = _make_salary_df(salaries=[50000.0, np.nan, 60000.0])
        result = filter_salaries(df, self._config())
        assert not result["ConvertedCompYearly"].isna().any()

    def test_returns_dataframe(self):
        """Returns a pandas DataFrame."""
        df = _make_salary_df()
        result = filter_salaries(
            df, self._config(lower_percentile=2, upper_percentile=98)
        )
        assert isinstance(result, pd.DataFrame)
|
|
|
|
class TestApplyCardinalityReduction:
    """Tests for apply_cardinality_reduction()."""

    def test_normalizes_unicode_apostrophes(self):
        """Unicode right single quotation marks are replaced in every row."""
        df = _make_salary_df(n=100)
        df["EdLevel"] = "Master\u2019s degree"
        result = apply_cardinality_reduction(df)
        # Check the whole column, not just the first row.
        ed_level = result["EdLevel"].astype(str)
        assert not ed_level.str.contains("\u2019", regex=False).any()

    def test_does_not_modify_original(self):
        """The input DataFrame is left completely unchanged."""
        df = _make_salary_df(n=5)
        # Give the function text it rewrites (see the apostrophe test above),
        # so an in-place normalization would show up in the comparison.
        df["EdLevel"] = "Master\u2019s degree"
        snapshot = df.copy(deep=True)
        apply_cardinality_reduction(df)
        # Compare every column, not a single cell.
        pd.testing.assert_frame_equal(df, snapshot)

    def test_rare_categories_become_other(self):
        """Categories below min_frequency are grouped into 'Other'."""
        countries = ["United States of America"] * 100 + ["Narnia"] * 2
        df = _make_salary_df(countries=countries)
        result = apply_cardinality_reduction(df)
        assert "Narnia" not in result["Country"].values
        assert "Other" in result["Country"].values
|
|
|
|
class TestDropOtherRows:
    """Tests for drop_other_rows()."""

    @staticmethod
    def _frame(countries, dev_types):
        """Build a frame with the given Country/DevType and constant other features."""
        n = len(countries)
        return pd.DataFrame(
            {
                "Country": countries,
                "DevType": dev_types,
                "EdLevel": ["BS"] * n,
                "Industry": ["SW"] * n,
                "Age": ["25-34"] * n,
                "ICorPM": ["IC"] * n,
                "OrgSize": ["Small"] * n,
                "Employment": ["FT"] * n,
            }
        )

    @staticmethod
    def _config(other_category, drop_other_from):
        """Build the ``features.cardinality`` config section drop_other_rows() reads."""
        return {
            "features": {
                "cardinality": {
                    "other_category": other_category,
                    "drop_other_from": drop_other_from,
                }
            }
        }

    def test_drops_other_from_specified_columns(self):
        """Rows with 'Other' in specified columns are dropped."""
        df = self._frame(["USA", "Other", "Germany"], ["Dev", "Dev", "Other"])
        result = drop_other_rows(df, self._config("Other", ["Country", "DevType"]))
        assert len(result) == 1
        assert result.iloc[0]["Country"] == "USA"

    def test_no_drop_when_list_empty(self):
        """No rows dropped when drop_other_from is empty."""
        df = self._frame(["USA", "Other"], ["Dev", "Other"])
        result = drop_other_rows(df, self._config("Other", []))
        assert len(result) == 2

    def test_uses_configured_other_name(self):
        """Uses the configured other_category name, not a hard-coded 'Other'."""
        df = self._frame(["USA", "Misc", "Other"], ["Dev", "Dev", "Dev"])
        result = drop_other_rows(df, self._config("Misc", ["Country"]))
        # Only the "Misc" row matches; a literal "Other" is an ordinary value
        # here and must survive.
        assert sorted(result["Country"]) == ["Other", "USA"]
|
|
|
|
class TestExtractValidCategories:
    """Tests for extract_valid_categories()."""

    def test_returns_sorted_unique_values(self):
        """Returns sorted unique values for each categorical feature."""
        df = pd.DataFrame(
            {
                "Country": ["Germany", "USA", "Germany"],
                "EdLevel": ["BS", "MS", "BS"],
                "DevType": ["Front", "Back", "Front"],
                "Industry": ["SW", "Fin", "SW"],
                "Age": ["25-34", "35-44", "25-34"],
                "ICorPM": ["IC", "PM", "IC"],
                "OrgSize": ["Small", "Large", "Small"],
                "Employment": ["FT", "PT", "FT"],
            }
        )
        result = extract_valid_categories(df)
        # Every one of the eight features is checked; several inputs are
        # deliberately out of order so sorting is actually exercised.
        assert result["Country"] == ["Germany", "USA"]
        assert result["EdLevel"] == ["BS", "MS"]
        assert result["DevType"] == ["Back", "Front"]
        assert result["Industry"] == ["Fin", "SW"]
        assert result["Age"] == ["25-34", "35-44"]
        assert result["ICorPM"] == ["IC", "PM"]
        assert result["OrgSize"] == ["Large", "Small"]
        assert result["Employment"] == ["FT", "PT"]

    def test_all_categorical_features_present(self):
        """All 8 categorical features are present as keys."""
        df = pd.DataFrame(
            {
                "Country": ["USA"],
                "EdLevel": ["BS"],
                "DevType": ["Dev"],
                "Industry": ["SW"],
                "Age": ["25-34"],
                "ICorPM": ["IC"],
                "OrgSize": ["Small"],
                "Employment": ["FT"],
            }
        )
        result = extract_valid_categories(df)
        assert set(result.keys()) == {
            "Country",
            "EdLevel",
            "DevType",
            "Industry",
            "Age",
            "ICorPM",
            "OrgSize",
            "Employment",
        }

    def test_excludes_nan_values(self):
        """NaN values are not included in valid categories."""
        df = pd.DataFrame(
            {
                "Country": ["USA", np.nan],
                "EdLevel": ["BS", "MS"],
                "DevType": ["Dev", "Dev"],
                "Industry": ["SW", "SW"],
                "Age": ["25-34", "25-34"],
                "ICorPM": ["IC", "IC"],
                "OrgSize": ["Small", "Small"],
                "Employment": ["FT", "FT"],
            }
        )
        result = extract_valid_categories(df)
        assert result["Country"] == ["USA"]
|
|
|
|
class TestComputeCurrencyRates:
    """Tests for compute_currency_rates()."""

    # Column layout shared by every fixture frame in this class.
    _COLUMNS = ["Country", "Currency", "CompTotal", "ConvertedCompYearly"]
    _USD = "USD United States Dollar"
    _EUR = "EUR European Euro"

    @classmethod
    def _frame(cls, *rows):
        """Build a fixture frame from (country, currency, comp, converted) tuples."""
        return pd.DataFrame(list(rows), columns=cls._COLUMNS)

    def test_computes_rates_for_valid_countries(self):
        """Returns currency rates for countries present in the data."""
        df = self._frame(
            ("USA", self._USD, 100000.0, 100000.0),
            ("USA", self._USD, 120000.0, 120000.0),
            ("Germany", self._EUR, 80000.0, 80000.0),
            ("Germany", self._EUR, 90000.0, 90000.0),
        )
        rates = compute_currency_rates(df, ["USA", "Germany"])
        for country in ("USA", "Germany"):
            assert country in rates
        assert rates["USA"]["code"] == "USD"
        assert rates["Germany"]["code"] == "EUR"
        assert isinstance(rates["USA"]["rate"], float)

    def test_skips_countries_not_in_data(self):
        """Countries absent from the data get no entry."""
        df = self._frame(("USA", self._USD, 100000.0, 100000.0))
        rates = compute_currency_rates(df, ["USA", "Narnia"])
        assert "USA" in rates
        assert "Narnia" not in rates

    def test_returns_dict_with_expected_keys(self):
        """Each country entry exposes exactly code, name, and rate."""
        df = self._frame(
            ("USA", self._USD, 100000.0, 100000.0),
            ("USA", self._USD, 100000.0, 100000.0),
        )
        rates = compute_currency_rates(df, ["USA"])
        assert set(rates["USA"].keys()) == {"code", "name", "rate"}

    def test_filters_extreme_rates(self):
        """Extreme conversion rates are filtered out."""
        # The second row's CompTotal/ConvertedCompYearly ratio is absurd and
        # must not pull the rate away from the 1:1 of the first row.
        df = self._frame(
            ("USA", self._USD, 100000.0, 100000.0),
            ("USA", self._USD, 0.0001, 100000.0),
        )
        rates = compute_currency_rates(df, ["USA"])
        assert rates["USA"]["rate"] == 1.0

    def test_empty_dataframe(self):
        """An empty DataFrame yields an empty dict."""
        empty = pd.DataFrame(columns=self._COLUMNS)
        assert compute_currency_rates(empty, ["USA"]) == {}
|
|