# agripredict / update_all_models.py
# ThejasRao's picture
# Upload 5 files
# 3029a46 verified
"""
Script to train and update all models for India, States, and Markets.
Run this script to update all forecasting models without using the UI.
"""
from itertools import product

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from tqdm import tqdm
from xgboost import XGBRegressor

from src.agri_predict import fetch_and_process_data
from src.agri_predict.config import get_collections
from src.agri_predict.constants import state_market_dict
from src.agri_predict.features import (
    create_forecasting_features,
    create_forecasting_features_1m,
    create_forecasting_features_3m,
)
# Forecast horizons, in days. Every scope handled by this script (India,
# each state, each market) gets one model trained per horizon listed here.
FORECAST_HORIZONS = [14, 30, 90] # 14 days, 1 month, 3 months
def train_model_batch(df, filter_key, days):
    """Tune, evaluate and record an XGBoost model without UI components.

    Builds horizon-specific features from ``df``, splits them on a fixed
    calendar date into a training and a hold-out part, grid-searches a small
    XGBoost hyperparameter space, and upserts the winning parameters together
    with their hold-out metrics into the MongoDB collection of that horizon.

    NOTE: the hold-out rows are used both to pick the hyperparameters and to
    compute the stored RMSE/MAE/score, so those metrics are optimistic.

    Args:
        df: Raw price data handed to the feature builder. The engineered
            frame must contain the 'Reported Date' and
            'Modal Price (Rs./Quintal)' columns.
        filter_key: Identifier stored with the result; the document with the
            same ``filter_key`` is replaced (or inserted if missing).
        days: Forecast horizon in days. 14 and 30 have dedicated settings;
            any other value uses the 3-month settings.

    Returns:
        Tuple ``(best_params, rmse, mae)`` where ``best_params`` is the dict
        of winning hyperparameters and the metrics are floats computed on the
        hold-out rows.

    Raises:
        ValueError: If the date split leaves the training or hold-out part
            empty, or if no hyperparameter combination yields a comparable
            (non-NaN) score.
    """
    cols = get_collections()

    # Horizon -> (feature builder, split date, collection key). Anything that
    # is not 14 or 30 days falls back to the 3-month configuration.
    three_month_settings = (
        create_forecasting_features_3m,
        '2023-01-01',
        'best_params_collection_3m',
    )
    horizon_settings = {
        14: (create_forecasting_features, '2024-01-01', 'best_params_collection'),
        30: (create_forecasting_features_1m, '2023-01-01', 'best_params_collection_1m'),
    }
    build_features, split_date, collection_key = horizon_settings.get(
        days, three_month_settings
    )
    df_features = build_features(df)

    # Chronological split: rows before the split date train the model, rows
    # on or after it form the hold-out set.
    target_col = 'Modal Price (Rs./Quintal)'
    date_col = 'Reported Date'
    train_df = df_features[df_features[date_col] < split_date]
    test_df = df_features[df_features[date_col] >= split_date]

    # Fail early with an actionable message instead of an opaque error from
    # deep inside the model or metric code.
    if train_df.empty or test_df.empty:
        raise ValueError(
            f"Not enough data for '{filter_key}' around split date {split_date}: "
            f"{len(train_df)} training rows, {len(test_df)} test rows"
        )

    X_train = train_df.drop(columns=[target_col, date_col])
    y_train = train_df[target_col]
    X_test = test_df.drop(columns=[target_col, date_col])
    y_test = test_df[target_col]

    # Hyperparameter tuning with progress bar
    param_grid = {
        'learning_rate': [0.01, 0.1, 0.2],
        'max_depth': [3, 5, 7],
        'n_estimators': [50, 100, 150],
        'booster': ['gbtree', 'dart'],
    }
    # product() walks the grid in the same order as nested for-loops over the
    # keys (last key varies fastest), so ties are resolved as before: the
    # first combination reaching the best score wins.
    combinations = list(product(*param_grid.values()))

    model = XGBRegressor()
    best_score = float('-inf')
    best_params = None
    for values in tqdm(combinations, desc=" Tuning hyperparameters"):
        candidate = dict(zip(param_grid, values))
        model.set_params(**candidate)
        model.fit(X_train, y_train)
        score = model.score(X_test, y_test)
        if score > best_score:
            best_score = score
            best_params = candidate

    # A NaN score never compares greater than the running best (this can
    # happen, for example, when the hold-out set has a single row), which
    # would leave no winner at all.
    if best_params is None:
        raise ValueError(
            f"No hyperparameter combination produced a valid score for "
            f"'{filter_key}' ({len(test_df)} test rows)"
        )

    # Train final model with best params
    best_model = XGBRegressor(**best_params)
    best_model.fit(X_train, y_train)
    y_pred = best_model.predict(X_test)

    # Calculate metrics; cast to built-in floats so the stored document and
    # the returned values are plain Python numbers.
    rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
    mae = float(mean_absolute_error(y_test, y_pred))

    # Save to MongoDB: one document per filter_key, replaced on every run.
    cols[collection_key].replace_one(
        {'filter_key': filter_key},
        {
            **best_params,
            'filter_key': filter_key,
            'last_updated': pd.Timestamp.now().isoformat(),
            'rmse': rmse,
            'mae': mae,
            'score': float(best_score),
        },
        upsert=True,
    )
    return best_params, rmse, mae
# Scopes trained by default; callers may pass their own lists instead.
_DEFAULT_STATES = ("Karnataka", "Madhya Pradesh", "Gujarat", "Uttar Pradesh", "Telangana")
_DEFAULT_MARKETS = ("Rajkot", "Gondal", "Kalburgi", "Amreli")


def _print_banner(title):
    """Print ``title`` framed by two 60-character rules, after a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _horizon_label(days):
    """Return the display name of a horizon (non-14/30 values read as 3 months)."""
    return {14: "14 days", 30: "1 month"}.get(days, "3 months")


def _has_data(df):
    """Return True when a fetch result is neither None nor an empty frame."""
    return df is not None and not getattr(df, "empty", False)


def _train_all_horizons(df, label, filter_key, leading_newline=False):
    """Train one model per forecast horizon for a single scope.

    Args:
        df: Data for the scope, passed unchanged to ``train_model_batch``.
        label: Name shown in square brackets in the log lines.
        filter_key: Key under which ``train_model_batch`` stores the result.
        leading_newline: Put a blank line before each "Training" message.

    Returns:
        Number of horizons whose training raised an exception.
    """
    prefix = "\n" if leading_newline else ""
    failures = 0
    for days in FORECAST_HORIZONS:
        horizon_name = _horizon_label(days)
        print(f"{prefix}[{label}] Training {horizon_name} forecast model...")
        try:
            _, rmse, mae = train_model_batch(df, filter_key, days)
        except Exception as e:
            # Best-effort batch: report the failure and carry on with the
            # remaining horizons instead of aborting the whole run.
            failures += 1
            print(f"❌ [{label}] Error updating {horizon_name} model: {e}")
        else:
            print(f"✅ [{label}] {horizon_name} model updated successfully")
            print(f" RMSE: {rmse:.2f}, MAE: {mae:.2f}")
    return failures


def update_india_models():
    """Update the models trained on all of India.

    Returns:
        Number of models that could not be updated (every horizon counts as
        failed when no data is available).
    """
    _print_banner("UPDATING INDIA MODELS")
    df = fetch_and_process_data({})
    if not _has_data(df):
        print("❌ [India] No data available")
        return len(FORECAST_HORIZONS)
    return _train_all_horizons(df, "India", "India", leading_newline=True)


def update_state_models(states=_DEFAULT_STATES):
    """Update the models of each state.

    Args:
        states: State names to process; each is used as the "State Name"
            query value and stored under the key ``state_<name>``.

    Returns:
        Number of models that could not be updated across all states.
    """
    _print_banner("UPDATING STATE MODELS")
    failures = 0
    for state in states:
        print(f"\n--- Processing State: {state} ---")
        df = fetch_and_process_data({"State Name": state})
        if not _has_data(df):
            print(f"❌ [{state}] No data available")
            failures += len(FORECAST_HORIZONS)
            continue
        failures += _train_all_horizons(df, state, f"state_{state}")
    return failures


def update_market_models(markets=_DEFAULT_MARKETS):
    """Update the models of specific markets.

    Args:
        markets: Market names to process; each is used as the "Market Name"
            query value and stored under the key ``market_<name>``.

    Returns:
        Number of models that could not be updated across all markets.
    """
    _print_banner("UPDATING MARKET MODELS")
    failures = 0
    for market in markets:
        print(f"\n--- Processing Market: {market} ---")
        df = fetch_and_process_data({"Market Name": market})
        if not _has_data(df):
            print(f"❌ [{market}] No data available")
            failures += len(FORECAST_HORIZONS)
            continue
        failures += _train_all_horizons(df, market, f"market_{market}")
    return failures


def main():
    """Run the India, state and market updates and print an honest summary.

    Individual training failures are reported by the update functions and do
    not stop the run; the closing banner states whether any occurred. A
    Ctrl-C or an unexpected error outside a single training ends the run with
    a message instead of a traceback.
    """
    print("\n" + "🌾" * 30)
    print("AGRIPREDICT - BATCH MODEL UPDATE")
    print("🌾" * 30)
    print("\nThis script will train and update all forecasting models.")
    print("This may take several minutes to complete.\n")
    try:
        failures = update_india_models()
        failures += update_state_models()
        failures += update_market_models()
    except KeyboardInterrupt:
        print("\n\n⚠️ Process interrupted by user")
    except Exception as e:
        print(f"\n\n❌ Fatal error: {e}")
    else:
        # Only claim success when every model was actually updated.
        if failures:
            _print_banner(f"⚠️ {failures} MODEL UPDATE(S) FAILED - SEE ERRORS ABOVE")
        else:
            _print_banner("✅ ALL MODELS UPDATED SUCCESSFULLY")
# Run the batch update only when this file is executed as a script, so that
# importing the module does not start any training.
if __name__ == "__main__":
    main()