| """Inference utilities for salary prediction.""" |
|
|
| import pickle |
| from pathlib import Path |
|
|
| import pandas as pd |
| import yaml |
|
|
| from src.schema import SalaryInput |
| from src.preprocessing import prepare_features |
|
|
| |
| model_path = Path("models/model.pkl") |
|
|
| if not model_path.exists(): |
| raise FileNotFoundError( |
| f"Model file not found at {model_path}. Please run 'python -m src.train' first." |
| ) |
|
|
| with open(model_path, "rb") as f: |
| artifacts = pickle.load(f) |
| model = artifacts["model"] |
| feature_columns = artifacts["feature_columns"] |
|
|
| |
| valid_categories_path = Path("config/valid_categories.yaml") |
|
|
| if not valid_categories_path.exists(): |
| raise FileNotFoundError( |
| f"Valid categories file not found at {valid_categories_path}. Please run 'python -m src.train' first." |
| ) |
|
|
| with open(valid_categories_path, "r") as f: |
| valid_categories = yaml.safe_load(f) |
|
|
| |
| currency_rates_path = Path("config/currency_rates.yaml") |
| currency_rates = {} |
| if currency_rates_path.exists(): |
| with open(currency_rates_path, "r") as f: |
| currency_rates = yaml.safe_load(f) or {} |
|
|
|
|
| def get_local_currency(country: str, salary_usd: float) -> dict | None: |
| """Convert USD salary to local currency for a given country. |
| |
| Returns: |
| Dict with code, name, rate, and salary_local, or None if unavailable. |
| """ |
| if country not in currency_rates: |
| return None |
| info = currency_rates[country] |
| return { |
| "code": info["code"], |
| "name": info["name"], |
| "rate": info["rate"], |
| "salary_local": round(salary_usd * info["rate"], 2), |
| } |
|
|
|
|
| def predict_salary(data: SalaryInput) -> float: |
| """Predict salary based on input features. |
| |
| Args: |
| data: SalaryInput model with developer information |
| |
| Returns: |
| Predicted annual salary in USD |
| |
| Raises: |
| ValueError: If country or education_level is not in valid categories |
| """ |
| |
| if data.country not in valid_categories["Country"]: |
| raise ValueError( |
| f"Invalid country: '{data.country}'. " |
| f"Must be one of {len(valid_categories['Country'])} valid countries. " |
| f"Check config/valid_categories.yaml for all valid values." |
| ) |
|
|
| if data.education_level not in valid_categories["EdLevel"]: |
| raise ValueError( |
| f"Invalid education level: '{data.education_level}'. " |
| f"Must be one of {len(valid_categories['EdLevel'])} valid education levels. " |
| f"Check config/valid_categories.yaml for all valid values." |
| ) |
|
|
| if data.dev_type not in valid_categories["DevType"]: |
| raise ValueError( |
| f"Invalid developer type: '{data.dev_type}'. " |
| f"Must be one of {len(valid_categories['DevType'])} valid developer types. " |
| f"Check config/valid_categories.yaml for all valid values." |
| ) |
|
|
| if data.industry not in valid_categories["Industry"]: |
| raise ValueError( |
| f"Invalid industry: '{data.industry}'. " |
| f"Must be one of {len(valid_categories['Industry'])} valid industries. " |
| f"Check config/valid_categories.yaml for all valid values." |
| ) |
|
|
| if data.age not in valid_categories["Age"]: |
| raise ValueError( |
| f"Invalid age: '{data.age}'. " |
| f"Must be one of {len(valid_categories['Age'])} valid age ranges. " |
| f"Check config/valid_categories.yaml for all valid values." |
| ) |
|
|
| if data.ic_or_pm not in valid_categories["ICorPM"]: |
| raise ValueError( |
| f"Invalid IC or PM value: '{data.ic_or_pm}'. " |
| f"Must be one of {len(valid_categories['ICorPM'])} valid values. " |
| f"Check config/valid_categories.yaml for all valid values." |
| ) |
|
|
| if data.org_size not in valid_categories["OrgSize"]: |
| raise ValueError( |
| f"Invalid organization size: '{data.org_size}'. " |
| f"Must be one of {len(valid_categories['OrgSize'])} valid sizes. " |
| f"Check config/valid_categories.yaml for all valid values." |
| ) |
|
|
| if data.employment not in valid_categories["Employment"]: |
| raise ValueError( |
| f"Invalid employment status: '{data.employment}'. " |
| f"Must be one of {valid_categories['Employment']}. " |
| f"Check config/valid_categories.yaml for all valid values." |
| ) |
|
|
| |
| input_df = pd.DataFrame( |
| { |
| "Country": [data.country], |
| "YearsCode": [data.years_code], |
| "WorkExp": [data.work_exp], |
| "EdLevel": [data.education_level], |
| "DevType": [data.dev_type], |
| "Industry": [data.industry], |
| "Age": [data.age], |
| "ICorPM": [data.ic_or_pm], |
| "OrgSize": [data.org_size], |
| "Employment": [data.employment], |
| } |
| ) |
|
|
| |
| input_encoded = prepare_features(input_df) |
|
|
| |
| |
| input_encoded = input_encoded.reindex(columns=feature_columns, fill_value=0) |
|
|
| |
| prediction = model.predict(input_encoded)[0] |
|
|
| |
| return max(0.0, float(prediction)) |
|
|