| | """ |
| | GCC Ramadan Retail Demand Forecasting - Inference Script |
| | |
| | This script demonstrates how to use the trained demand forecasting model. |
| | When downloaded from HuggingFace, this script works alongside model.joblib, encoders.joblib, and config.json. |
| | |
| | Usage: |
| | python inference.py |
| | |
| | Or import and use programmatically: |
| | from inference import DemandForecaster |
| | forecaster = DemandForecaster() |
| | prediction = forecaster.predict(...) |
| | """ |
| |
|
| | import joblib |
| | import json |
| | import numpy as np |
| | import os |
| |
|
| | |
# Directory holding this script; model artifacts are looked up here by default.
MODEL_DIR = os.path.split(os.path.abspath(__file__))[0]
| |
|
| |
|
class DemandForecaster:
    """Load and query the GCC Ramadan retail demand forecasting model.

    The model directory is expected to contain three artifacts:

    * ``model.joblib`` -- the fitted model; only its ``predict`` method is used.
    * ``encoders.joblib`` -- a dict with ``'country_encoder'`` and
      ``'category_encoder'`` entries, each exposing ``transform``.
    * ``config.json`` -- metadata; ``'countries'`` and ``'categories'`` list the
      values accepted by :meth:`predict`.
    """

    def __init__(self, model_dir=None):
        """
        Initialize the forecaster by loading the model and encoders.

        Args:
            model_dir: Path to directory containing model files.
                Defaults to the directory containing this script (any falsy
                value, including an empty string, selects the default).
        """
        self.model_dir = model_dir or MODEL_DIR
        self._load_model()

    def _load_model(self):
        """Load the trained model, encoders, and configuration from ``self.model_dir``."""
        # SECURITY: joblib.load unpickles its input and can execute arbitrary
        # code. Only load artifacts obtained from a source you trust.
        model_path = os.path.join(self.model_dir, "model.joblib")
        self.model = joblib.load(model_path)

        encoders_path = os.path.join(self.model_dir, "encoders.joblib")
        encoders = joblib.load(encoders_path)
        self.country_encoder = encoders['country_encoder']
        self.category_encoder = encoders['category_encoder']

        config_path = os.path.join(self.model_dir, "config.json")
        # Explicit encoding: the platform default locale may not be UTF-8.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        self.countries = self.config['countries']
        self.categories = self.config['categories']

    def _feature_row(self, is_ramadan, ramadan_week, days_to_eid, is_eid_fitr,
                     is_eid_adha, is_hajj_season, country, category,
                     temperature, day_of_week, month, hijri_month, hijri_day):
        """Validate one record and return its model-ready feature list.

        Takes the same parameters as :meth:`predict`. Calling it with
        ``**record`` raises ``TypeError`` for missing or unexpected keys, just
        as :meth:`predict` would.

        Raises:
            ValueError: If ``country`` or ``category`` is not in the configured
                lists.
        """
        if country not in self.countries:
            raise ValueError(f"Invalid country: {country}. Must be one of {self.countries}")
        if category not in self.categories:
            raise ValueError(f"Invalid category: {category}. Must be one of {self.categories}")

        country_encoded = self.country_encoder.transform([country])[0]
        category_encoded = self.category_encoder.transform([category])[0]

        # Column order is unchanged from the original script; presumably it
        # matches the training-time feature order -- do not reorder.
        return [
            is_ramadan,
            ramadan_week,
            days_to_eid,
            is_eid_fitr,
            is_eid_adha,
            is_hajj_season,
            country_encoded,
            category_encoded,
            temperature,
            day_of_week,
            month,
            hijri_month,
            hijri_day,
        ]

    def predict(self,
                is_ramadan: int,
                ramadan_week: int,
                days_to_eid: int,
                is_eid_fitr: int,
                is_eid_adha: int,
                is_hajj_season: int,
                country: str,
                category: str,
                temperature: float,
                day_of_week: int,
                month: int,
                hijri_month: int,
                hijri_day: int) -> float:
        """
        Predict demand index for given features.

        Args:
            is_ramadan: 1 if Ramadan, 0 otherwise
            ramadan_week: Week of Ramadan (1-5), 0 if not Ramadan
            days_to_eid: Days until Eid al-Fitr (-1 if not applicable)
            is_eid_fitr: 1 if Eid al-Fitr, 0 otherwise
            is_eid_adha: 1 if Eid al-Adha, 0 otherwise
            is_hajj_season: 1 if Hajj season, 0 otherwise
            country: One of UAE, KSA, Qatar, Kuwait, Bahrain, Oman
            category: Product category (dates_sweets, electronics, fashion_abayas,
                gifts, groceries, perfumes_oud)
            temperature: Temperature in Celsius
            day_of_week: Day of week (0-6, Monday=0)
            month: Gregorian month (1-12)
            hijri_month: Hijri month (1-12)
            hijri_day: Hijri day (1-30)

        Returns:
            Predicted demand index as a builtin ``float`` (typically 30-200 range)

        Raises:
            ValueError: If ``country`` or ``category`` is not in the model config.
        """
        row = self._feature_row(
            is_ramadan, ramadan_week, days_to_eid, is_eid_fitr, is_eid_adha,
            is_hajj_season, country, category, temperature, day_of_week,
            month, hijri_month, hijri_day,
        )
        # The model scores a 2-D (n_samples, n_features) array; take the only row.
        return float(self.model.predict(np.array([row]))[0])

    def predict_dict(self, data: dict) -> float:
        """
        Predict demand index from a dictionary of features.

        Args:
            data: Dictionary with keys matching the predict() parameters

        Returns:
            Predicted demand index
        """
        return self.predict(**data)

    def predict_batch(self, data_list: list) -> list:
        """
        Predict demand index for multiple records.

        Every record is validated and encoded first, then all records are
        scored in a single model call rather than one call per record.

        Args:
            data_list: List of dictionaries with feature values (keys matching
                the predict() parameters)

        Returns:
            List of predicted demand indices (builtin floats), in input order.
            An empty input yields an empty list without calling the model.

        Raises:
            ValueError: If any record has an unknown country or category.
            TypeError: If any record has missing or unexpected keys.
        """
        rows = [self._feature_row(**record) for record in data_list]
        if not rows:
            return []
        return [float(value) for value in self.model.predict(np.array(rows))]
| |
|
| |
|
def demo():
    """Print a short walkthrough: model metadata followed by sample forecasts.

    Builds a DemandForecaster from the default model directory, prints the
    accepted countries/categories and the stored R2/RMSE metrics, then prints
    the predicted demand index for a few hand-picked calendar scenarios.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print(heavy_rule)
    print("GCC Ramadan Retail Demand Forecasting - Demo")
    print(heavy_rule)

    forecaster = DemandForecaster()

    print(f"\nAvailable countries: {forecaster.countries}")
    print(f"Available categories: {forecaster.categories}")
    # Looked up only after the lists above are printed, as in the original flow.
    metrics = forecaster.config['metrics']
    r2_text = f"{metrics['r2_score']:.3f}"
    rmse_text = f"{metrics['rmse']:.2f}"
    print(f"\nModel metrics: R2={r2_text}, RMSE={rmse_text}")

    print("\n" + light_rule)
    print("Example Predictions:")
    print(light_rule)

    # (label, keyword arguments for DemandForecaster.predict)
    scenarios = (
        ("Normal day in UAE (groceries)", dict(
            is_ramadan=0, ramadan_week=0, days_to_eid=-1,
            is_eid_fitr=0, is_eid_adha=0, is_hajj_season=0,
            country="UAE", category="groceries", temperature=25.0,
            day_of_week=5, month=6, hijri_month=11, hijri_day=15,
        )),
        ("Ramadan Week 2 in KSA (dates_sweets)", dict(
            is_ramadan=1, ramadan_week=2, days_to_eid=15,
            is_eid_fitr=0, is_eid_adha=0, is_hajj_season=0,
            country="KSA", category="dates_sweets", temperature=30.0,
            day_of_week=4, month=4, hijri_month=9, hijri_day=15,
        )),
        ("Eid al-Fitr in Qatar (gifts)", dict(
            is_ramadan=0, ramadan_week=0, days_to_eid=0,
            is_eid_fitr=1, is_eid_adha=0, is_hajj_season=0,
            country="Qatar", category="gifts", temperature=35.0,
            day_of_week=0, month=5, hijri_month=10, hijri_day=1,
        )),
        ("Hajj season in KSA (perfumes_oud)", dict(
            is_ramadan=0, ramadan_week=0, days_to_eid=-1,
            is_eid_fitr=0, is_eid_adha=0, is_hajj_season=1,
            country="KSA", category="perfumes_oud", temperature=40.0,
            day_of_week=3, month=7, hijri_month=12, hijri_day=8,
        )),
        ("Eid al-Adha in Kuwait (fashion_abayas)", dict(
            is_ramadan=0, ramadan_week=0, days_to_eid=-1,
            is_eid_fitr=0, is_eid_adha=1, is_hajj_season=1,
            country="Kuwait", category="fashion_abayas", temperature=42.0,
            day_of_week=5, month=7, hijri_month=12, hijri_day=10,
        )),
    )

    for position, (label, features) in enumerate(scenarios, start=1):
        demand = forecaster.predict(**features)
        print(f"\n{position}. {label}: {demand:.2f}")

    print("\n" + heavy_rule)
    print("Demo complete!")
    print(heavy_rule)
| |
|
| |
|
if __name__ == "__main__":
    # Executing the file directly (e.g. ``python inference.py``) runs the demo;
    # importing it as a module does not.
    demo()
| |
|