# Spaces: Sleeping
"""
Script to train and update all models for India, States, and Markets.
Run this script to update all forecasting models without using the UI.
"""
import itertools

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from xgboost import XGBRegressor
from tqdm import tqdm

from src.agri_predict import fetch_and_process_data
from src.agri_predict.constants import state_market_dict
from src.agri_predict.features import (
    create_forecasting_features,
    create_forecasting_features_1m,
    create_forecasting_features_3m,
)
from src.agri_predict.config import get_collections
# Forecast horizons, in days, that every scope is trained for:
# 14 days, 1 month (30) and 3 months (90).
FORECAST_HORIZONS = [14, 30, 90]
def train_model_batch(df, filter_key, days):
    """Tune, train and persist an XGBoost price model without any UI.

    The input is turned into a feature frame by the horizon-specific feature
    builder, split into train/test by a fixed cut-off date, grid-searched
    over a small XGBoost parameter grid, refitted with the winning
    parameters, and the parameters plus metrics are upserted into the
    MongoDB collection for that horizon.

    Args:
        df: Raw data handed to the feature builder. The resulting feature
            frame must contain the columns ``'Reported Date'`` and
            ``'Modal Price (Rs./Quintal)'``.
        filter_key: Identifier stored with the result and used as the upsert
            key (callers pass e.g. ``"India"``, ``"state_<name>"``,
            ``"market_<name>"``).
        days: Forecast horizon. ``14`` and ``30`` select their own feature
            builder, cut-off date and collection; any other value is handled
            as the 3-month (90 day) configuration.

    Returns:
        Tuple ``(best_params, rmse, mae)`` where ``best_params`` is the
        winning parameter dict and the metrics are plain floats computed on
        the test split.

    Raises:
        ValueError: If ``df`` is ``None``/empty, if either side of the
            train/test split is empty, or if no parameter combination
            produced a comparable (non-NaN) score.
    """
    # Fail early with a readable message instead of an obscure error from
    # deep inside feature creation or XGBoost.
    if df is None or len(df) == 0:
        raise ValueError(f"No data to train on for '{filter_key}'")

    cols = get_collections()

    # Horizon-specific configuration. Everything that is not 14 or 30 falls
    # through to the 3-month setup, as before.
    if days == 14:
        build_features = create_forecasting_features
        split_date = '2024-01-01'
        collection_key = 'best_params_collection'
    elif days == 30:
        build_features = create_forecasting_features_1m
        split_date = '2023-01-01'
        collection_key = 'best_params_collection_1m'
    else:  # 90 days
        build_features = create_forecasting_features_3m
        split_date = '2023-01-01'
        collection_key = 'best_params_collection_3m'
    df_features = build_features(df)

    target_col = 'Modal Price (Rs./Quintal)'
    date_col = 'Reported Date'

    # Chronological split: rows before the cut-off train, the rest test.
    train_df = df_features[df_features[date_col] < split_date]
    test_df = df_features[df_features[date_col] >= split_date]
    if train_df.empty or test_df.empty:
        raise ValueError(
            f"Cannot train '{filter_key}' ({days}-day horizon): "
            f"{len(train_df)} training rows and {len(test_df)} test rows "
            f"around split date {split_date}"
        )

    non_feature_cols = [target_col, date_col]
    X_train = train_df.drop(columns=non_feature_cols)
    y_train = train_df[target_col]
    X_test = test_df.drop(columns=non_feature_cols)
    y_test = test_df[target_col]

    param_grid = {
        'learning_rate': [0.01, 0.1, 0.2],
        'max_depth': [3, 5, 7],
        'n_estimators': [50, 100, 150],
        'booster': ['gbtree', 'dart'],
    }
    # product() walks the grid in the same order as nested loops over the
    # keys in declaration order, so ties are still won by the earliest
    # combination.
    param_names = list(param_grid)
    candidates = [
        dict(zip(param_names, values))
        for values in itertools.product(*param_grid.values())
    ]

    # NOTE(review): candidates are ranked on the same hold-out split that the
    # reported RMSE/MAE are computed on, so those metrics are optimistic.
    model = XGBRegressor()
    best_score = float('-inf')
    best_params = None
    for candidate in tqdm(candidates, desc=" Tuning hyperparameters"):
        model.set_params(**candidate)
        model.fit(X_train, y_train)
        score = model.score(X_test, y_test)
        if score > best_score:
            best_score = score
            best_params = candidate

    # A NaN score never compares greater than -inf, which would leave
    # best_params unset and crash on ``**None`` below.
    if best_params is None:
        raise ValueError(
            f"No usable hyperparameter score for '{filter_key}' "
            f"({days}-day horizon); the test split has {len(test_df)} rows"
        )

    # Refit a fresh model with the winning parameters and evaluate it.
    best_model = XGBRegressor(**best_params)
    best_model.fit(X_train, y_train)
    y_pred = best_model.predict(X_test)

    # Plain floats keep the stored document free of NumPy scalar types.
    rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
    mae = float(mean_absolute_error(y_test, y_pred))

    # One document per filter_key: replace it, or insert it if missing.
    cols[collection_key].replace_one(
        {'filter_key': filter_key},
        {
            **best_params,
            'filter_key': filter_key,
            'last_updated': pd.Timestamp.now().isoformat(),
            'rmse': rmse,
            'mae': mae,
            'score': float(best_score),
        },
        upsert=True,
    )
    return best_params, rmse, mae
# Scopes retrained when the caller does not pass an explicit list.
_DEFAULT_STATES = ("Karnataka", "Madhya Pradesh", "Gujarat", "Uttar Pradesh", "Telangana")
_DEFAULT_MARKETS = ("Rajkot", "Gondal", "Kalburgi", "Amreli")


def _horizon_label(days):
    """Return the name used in log lines for a forecast horizon in days."""
    if days == 14:
        return "14 days"
    if days == 30:
        return "1 month"
    return "3 months"


def _update_models_for(label, filter_key, query_filter, blank_line_before=False):
    """Fetch the data for one scope and retrain every horizon for it.

    Args:
        label: Name shown in square brackets in the log lines.
        filter_key: Key passed to ``train_model_batch`` for storage.
        query_filter: Filter passed to ``fetch_and_process_data``.
        blank_line_before: Print an empty line before each
            "Training ..." message (the India run does this).

    Returns:
        Number of problems: 1 if no data was available, otherwise one per
        horizon whose training raised.
    """
    df = fetch_and_process_data(query_filter)
    # An empty frame cannot be trained on either; report it once instead of
    # once per horizon.
    if df is None or getattr(df, "empty", False):
        print(f"❌ [{label}] No data available")
        return 1

    prefix = "\n" if blank_line_before else ""
    problems = 0
    for days in FORECAST_HORIZONS:
        horizon_name = _horizon_label(days)
        print(f"{prefix}[{label}] Training {horizon_name} forecast model...")
        try:
            _, rmse, mae = train_model_batch(df, filter_key, days)
        except Exception as e:
            # Deliberately broad: one failing model must not stop the batch.
            problems += 1
            print(f"❌ [{label}] Error updating {horizon_name} model: {e}")
        else:
            print(f"✅ [{label}] {horizon_name} model updated successfully")
            print(f" RMSE: {rmse:.2f}, MAE: {mae:.2f}")
    return problems


def update_india_models():
    """Update the models trained on the unfiltered, all-India data.

    Returns:
        Number of problems encountered (0 when every horizon was updated).
    """
    print("\n" + "="*60)
    print("UPDATING INDIA MODELS")
    print("="*60)
    return _update_models_for("India", "India", {}, blank_line_before=True)


def update_state_models(states=None):
    """Update the models for each state.

    Args:
        states: Iterable of state names to retrain. ``None`` (the default)
            retrains the built-in list of five states.

    Returns:
        Total number of problems encountered across all states.
    """
    print("\n" + "="*60)
    print("UPDATING STATE MODELS")
    print("="*60)
    if states is None:
        states = _DEFAULT_STATES
    problems = 0
    for state in states:
        print(f"\n--- Processing State: {state} ---")
        problems += _update_models_for(
            state, f"state_{state}", {"State Name": state}
        )
    return problems


def update_market_models(markets=None):
    """Update the models for specific markets.

    Args:
        markets: Iterable of market names to retrain. ``None`` (the default)
            retrains the built-in list of four markets.

    Returns:
        Total number of problems encountered across all markets.
    """
    print("\n" + "="*60)
    print("UPDATING MARKET MODELS")
    print("="*60)
    if markets is None:
        markets = _DEFAULT_MARKETS
    problems = 0
    for market in markets:
        print(f"\n--- Processing Market: {market} ---")
        problems += _update_models_for(
            market, f"market_{market}", {"Market Name": market}
        )
    return problems
def main():
    """Run the India, state and market updates in order and report the outcome.

    Returns:
        Process exit status: ``0`` when no update step reported a problem,
        ``1`` when at least one problem was reported or a fatal error
        occurred, ``130`` when the run was interrupted with Ctrl-C.
    """
    print("\n" + "🌾" * 30)
    print("AGRIPREDICT - BATCH MODEL UPDATE")
    print("🌾" * 30)
    print("\nThis script will train and update all forecasting models.")
    print("This may take several minutes to complete.\n")
    try:
        problems = 0
        for update in (update_india_models, update_state_models, update_market_models):
            # A step that returns nothing is counted as having no problems.
            problems += update() or 0

        print("\n" + "="*60)
        if problems:
            # The steps swallow per-model errors, so only claim success when
            # none of them reported a problem.
            print(f"⚠️ UPDATE FINISHED WITH {problems} PROBLEM(S) - see messages above")
        else:
            print("✅ ALL MODELS UPDATED SUCCESSFULLY")
        print("="*60)
        return 1 if problems else 0
    except KeyboardInterrupt:
        print("\n\n⚠️ Process interrupted by user")
        return 130
    except Exception as e:
        print(f"\n\n❌ Fatal error: {e}")
        return 1


if __name__ == "__main__":
    # Propagate the outcome as the process exit status for schedulers/CI.
    raise SystemExit(main())