# agripredict / streamlit_app.py
# ThejasRao's picture
# Fix: Readme
# a820271
"""Streamlit entrypoint for AgriPredict (refactored).
Run with: `streamlit run streamlit_app.py` from project root.
"""
import streamlit as st
from datetime import datetime, timedelta
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import os
from dotenv import load_dotenv
from src.agri_predict import (
fetch_and_process_data,
fetch_and_store_data,
preprocess_data,
train_and_forecast,
forecast,
collection_to_dataframe,
get_dataframe_from_collection,
)
from src.agri_predict.constants import state_market_dict
from src.agri_predict.utils import authenticate_user
from src.agri_predict.config import get_collections
# Load environment variables (python-dotenv) before any setting is read.
load_dotenv()

# IS_PROD is True only when the PROD variable equals "true" (case-insensitive).
# Further down it is used to hide the "Train and Forecast" buttons.
_prod_flag = os.environ.get("PROD", "False")
IS_PROD = _prod_flag.lower() == "true"

# Page-wide layout; this is the first Streamlit call made by the script.
st.set_page_config(layout="wide")
def _records_to_dataframe(records) -> pd.DataFrame:
    """Turn a list of raw documents into a normalized DataFrame.

    Shared by ``load_all_data`` and ``get_filtered_data`` so both apply the
    exact same clean-up rules.

    Args:
        records: List of dict-like documents (possibly empty).

    Returns:
        An empty DataFrame when ``records`` is empty; otherwise a DataFrame
        with the ``_id`` column removed, ``Reported Date`` parsed to
        datetimes, and the price/arrival columns converted to numbers
        (unparseable values become NaN).
    """
    if not records:
        return pd.DataFrame()

    df = pd.DataFrame(records)

    # The MongoDB document id is not useful for analysis.
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])

    if 'Reported Date' in df.columns:
        df['Reported Date'] = pd.to_datetime(df['Reported Date'])

    # errors='coerce' keeps a single bad cell from aborting the whole load.
    for column in ('Modal Price (Rs./Quintal)', 'Arrivals (Tonnes)'):
        if column in df.columns:
            df[column] = pd.to_numeric(df[column], errors='coerce')

    return df


def load_all_data(_collection) -> pd.DataFrame:
    """Load every document of a collection into a normalized DataFrame.

    Args:
        _collection: Object exposing ``find(filter)`` that yields dict-like
            documents (presumably a pymongo collection - confirm with
            ``get_collections``).

    Returns:
        The normalized DataFrame, or an empty DataFrame when the collection
        holds no documents.
    """
    return _records_to_dataframe(list(_collection.find({})))


def get_filtered_data(_collection, state=None, market=None, days: int = 30) -> pd.DataFrame:
    """Load recent documents, optionally restricted to a state and/or market.

    Args:
        _collection: Object exposing ``find(filter)`` that yields dict-like
            documents.
        state: State name to match on ``State Name``. ``None``, an empty
            value, or ``'India'`` means no state restriction.
        market: Market name to match on ``Market Name``; falsy means no
            market restriction.
        days: Only documents whose ``Reported Date`` is within the last
            ``days`` days (relative to ``datetime.now()``) are requested.

    Returns:
        The normalized DataFrame, or an empty DataFrame when nothing matches.
    """
    cutoff = datetime.now() - timedelta(days=days)
    query_filter = {"Reported Date": {"$gte": cutoff}}
    if state and state != 'India':
        query_filter["State Name"] = state
    if market:
        query_filter["Market Name"] = market

    return _records_to_dataframe(list(_collection.find(query_filter)))
# Page-wide CSS overrides: constrain the content width, pad the main block
# container, and apply the green accent to the title and buttons.
_CUSTOM_CSS = """
<style>
.main {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
}
.block-container {
max-width: 1200px;
padding-left: 5rem;
padding-right: 5rem;
}
h1 { color: #4CAF50; font-family: 'Arial Black', sans-serif; }
.stButton>button { background-color: #4CAF50; color: white; }
</style>
"""

# unsafe_allow_html is needed so the <style> tag is emitted as raw HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Look-back windows (in days) for the "Plots" view. The sidebar options are
# derived from this mapping so the two can never drift apart (previously
# "2 Years" was mapped here but never offered in the selectbox).
_PLOT_PERIOD_DAYS = {
    "2 Weeks": 14,
    "1 Month": 30,
    "3 Months": 90,
    "1 Year": 365,
    "2 Years": 730,
    "5 Years": 1825,
}

# Forecast horizon labels -> number of days passed to the model helpers.
_FORECAST_HORIZON_DAYS = {"14 days": 14, "1 month": 30, "3 month": 90}

# "Exim" view: look-back label -> offset subtracted from the current time.
_EXIM_LOOKBACK = {
    "1 Month": pd.DateOffset(months=1),
    "6 Months": pd.DateOffset(months=6),
    "1 Year": pd.DateOffset(years=1),
    "2 Years": pd.DateOffset(years=2),
}

# "Exim" view: plot option -> (source column, aggregation, display name, y-axis label).
_EXIM_SERIES = {
    "Import Price": ("VALUE_IMPORT", "mean", "Average Import Price", "Average Import Price (Rs.)"),
    "Import Quantity": ("QUANTITY_IMPORT", "sum", "Total Import Quantity", "Total Import Quantity (Tonnes)"),
    "Export Price": ("VALUE_EXPORT", "mean", "Average Export Price", "Average Export Price (Rs.)"),
    # Fixed: this option used to aggregate QUANTITY_IMPORT (copy-paste bug).
    # NOTE(review): QUANTITY_EXPORT is inferred from the VALUE_* naming pair -
    # confirm against the impExp collection schema.
    "Export Quantity": ("QUANTITY_EXPORT", "sum", "Total Export Quantity", "Total Export Quantity (Tonnes)"),
}


def _render_plots_view(collection):
    """Render the "Plots" view: sidebar filters plus price/arrival trend charts.

    Args:
        collection: Collection handed to ``get_filtered_data``.
    """
    st.sidebar.header("Filters")
    period_label = st.sidebar.selectbox("Select Time Period", list(_PLOT_PERIOD_DAYS), index=1)
    st.session_state.selected_period = _PLOT_PERIOD_DAYS[period_label]

    state_options = list(state_market_dict.keys()) + ['India']
    selected_state = st.sidebar.selectbox("Select", state_options)

    # Market-level analysis only makes sense for a concrete state.
    market_wise = False
    selected_market = None
    if selected_state != 'India':
        market_wise = st.sidebar.checkbox("Market Wise Analysis")
        if market_wise:
            markets = state_market_dict.get(selected_state, [])
            selected_market = st.sidebar.selectbox("Select Market", markets)

    data_type = st.sidebar.radio("Select Data Type", ["Price", "Volume", "Both"])
    if not st.sidebar.button("✨ Let's go!"):
        return

    price_col = 'Modal Price (Rs./Quintal)'
    arrivals_col = 'Arrivals (Tonnes)'

    try:
        state_param = selected_state if selected_state != 'India' else None
        market_param = selected_market if market_wise else None
        df = get_filtered_data(collection, state_param, market_param, st.session_state.selected_period)
        if df.empty:
            st.warning("⚠️ No data found for the selected filters.")
            return

        # One row per day: total arrivals, mean modal price.
        df_grouped = df.groupby('Reported Date', as_index=False).agg({
            arrivals_col: 'sum',
            price_col: 'mean',
        })

        # Reindex onto a gap-free daily range, then fill missing days from
        # their neighbours (forward first, backward for leading gaps).
        date_range = pd.date_range(
            start=df_grouped['Reported Date'].min(),
            end=df_grouped['Reported Date'].max(),
            freq='D',
        )
        df_grouped = (
            df_grouped.set_index('Reported Date')
            .reindex(date_range)
            .rename_axis('Reported Date')
            .reset_index()
        )
        df_grouped[arrivals_col] = df_grouped[arrivals_col].ffill().bfill()
        df_grouped[price_col] = df_grouped[price_col].ffill().bfill()

        scope_label = f"Market: {selected_market}" if market_wise else "State"
        st.subheader(f"📈 Trends for {selected_state} ({scope_label})")

        import plotly.graph_objects as go

        dates = df_grouped['Reported Date']
        fig = go.Figure()
        if data_type == "Both":
            # Min-max scale both series so they share one axis; the raw
            # values stay available in the hover text.
            scaler = MinMaxScaler()
            df_grouped[['Scaled Price', 'Scaled Arrivals']] = scaler.fit_transform(
                df_grouped[[price_col, arrivals_col]]
            )
            fig.add_trace(go.Scatter(
                x=dates,
                y=df_grouped['Scaled Price'],
                mode='lines',
                name='Scaled Price',
                line=dict(width=1, color='green'),
                text=df_grouped[price_col],
                hovertemplate='Date: %{x}<br>Scaled Price: %{y:.2f}<br>Actual Price: %{text:.2f}<extra></extra>',
            ))
            fig.add_trace(go.Scatter(
                x=dates,
                y=df_grouped['Scaled Arrivals'],
                mode='lines',
                name='Scaled Arrivals',
                line=dict(width=1, color='blue'),
                text=df_grouped[arrivals_col],
                hovertemplate='Date: %{x}<br>Scaled Arrivals: %{y:.2f}<br>Actual Arrivals: %{text:.2f}<extra></extra>',
            ))
            fig.update_layout(
                title="Price and Arrivals Trend",
                xaxis_title='Date',
                yaxis_title='Scaled Values',
                template='plotly_white',
            )
        elif data_type == "Price":
            fig.add_trace(go.Scatter(
                x=dates,
                y=df_grouped[price_col],
                mode='lines',
                name='Modal Price',
                line=dict(width=1, color='green'),
            ))
            # Axis title typo fixed (was 'Price (/Quintall)').
            fig.update_layout(
                title="Modal Price Trend",
                xaxis_title='Date',
                yaxis_title='Price (Rs./Quintal)',
                template='plotly_white',
            )
        else:  # "Volume"
            fig.add_trace(go.Scatter(
                x=dates,
                y=df_grouped[arrivals_col],
                mode='lines',
                name='Arrivals',
                line=dict(width=1, color='blue'),
            ))
            fig.update_layout(
                title="Arrivals Trend",
                xaxis_title='Date',
                yaxis_title='Volume (in Tonnes)',
                template='plotly_white',
            )
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        # UI boundary: report the failure in the page instead of a traceback.
        st.error(f"❌ Error fetching data: {e}")


def _run_forecast_actions(query_filter, filter_key, horizon_days, missing_message):
    """Render the train/forecast buttons and run the one that was clicked.

    Both buttons are always rendered (the training button only outside
    production) so the three prediction scopes behave identically; the Market
    scope previously hid "Forecast" on the run where training was clicked.

    Args:
        query_filter: Filter passed to ``fetch_and_process_data``.
        filter_key: Model key passed to ``train_and_forecast`` / ``forecast``.
        horizon_days: Forecast horizon in days.
        missing_message: Error shown when no data comes back (``None``).
    """
    if not IS_PROD and st.button("Train and Forecast"):
        df = fetch_and_process_data(query_filter)
        if df is not None:
            train_and_forecast(df, filter_key, horizon_days)
        else:
            st.error(missing_message)

    if st.button("Forecast"):
        df = fetch_and_process_data(query_filter)
        if df is not None:
            forecast(df, filter_key, horizon_days)
        else:
            st.error(missing_message)


def _render_predictions_view():
    """Render the "Predictions" view for the India, state, or market scope."""
    st.subheader("📊 Model Analysis")
    sub_option = st.radio("Select one of the following", ["India", "States", "Market"], horizontal=True)
    sub_timeline = st.radio(
        "Select one of the following horizons", list(_FORECAST_HORIZON_DAYS), horizontal=True
    )
    horizon_days = _FORECAST_HORIZON_DAYS[sub_timeline]

    if sub_option == "States":
        states = ["Karnataka", "Madhya Pradesh", "Gujarat", "Uttar Pradesh", "Telangana"]
        selected_state = st.selectbox("Select State for Model Training", states)
        _run_forecast_actions(
            {"State Name": selected_state},
            f"state_{selected_state}",
            horizon_days,
            "❌ No data available for the selected state.",
        )
    elif sub_option == "Market":
        market_options = ["Rajkot", "Gondal", "Kalburgi", "Amreli"]
        selected_market = st.selectbox("Select Market for Model Training", market_options)
        _run_forecast_actions(
            {"Market Name": selected_market},
            f"market_{selected_market}",
            horizon_days,
            "❌ No data available for the selected market.",
        )
    elif sub_option == "India":
        # The import/export frame that used to be loaded here was never read,
        # so the extra load on every rerun has been dropped.
        _run_forecast_actions({}, "India", horizon_days, "❌ No data available for forecasting.")


def _render_statistics_view(collection):
    """Render the "Statistics" view over the full collection."""
    # Loaded on every rerun; load_all_data is not cached.
    df = load_all_data(collection)
    if not df.empty:
        from src.agri_predict.plotting import display_statistics
        display_statistics(df)
    else:
        st.warning("No data available to display statistics.")


def _render_exim_view(exim_collection):
    """Render the "Exim" view: import/export price or quantity over time.

    Args:
        exim_collection: Collection handed to ``collection_to_dataframe``.
    """
    df = collection_to_dataframe(exim_collection)
    plot_option = st.radio("Select the data to visualize:", list(_EXIM_SERIES), horizontal=True)
    time_period = st.selectbox("Select time period:", list(_EXIM_LOOKBACK))

    # Guard against an empty result instead of failing with a KeyError below.
    # Assumes collection_to_dataframe returns a pandas DataFrame (or None).
    if df is None or df.empty or "Reported Date" not in df.columns:
        st.warning("⚠️ No import/export data available.")
        return

    source_col, aggregate, display_name, y_axis_label = _EXIM_SERIES[plot_option]
    if source_col not in df.columns:
        st.error(f"❌ Column '{source_col}' is missing from the import/export data.")
        return

    df["Reported Date"] = pd.to_datetime(df["Reported Date"], format="%Y-%m-%d")
    start_date = pd.Timestamp.now() - _EXIM_LOOKBACK[time_period]
    filtered_df = df[df["Reported Date"] >= start_date]

    grouped_df = (
        filtered_df.groupby("Reported Date", as_index=False)[source_col]
        .agg(aggregate)
        .rename(columns={source_col: display_name})
    )

    import plotly.express as px
    fig = px.line(
        grouped_df,
        x="Reported Date",
        y=display_name,
        title=f"{plot_option} Over Time",
        labels={"Reported Date": "Date", display_name: y_axis_label},
    )
    st.plotly_chart(fig)


def _render_dashboard():
    """Render the authenticated dashboard and dispatch to the selected view."""
    try:
        cols = get_collections()
    except Exception as exc:
        st.error(f"Configuration error: {exc}")
        st.stop()  # halts this script run; nothing below executes

    collection = cols['collection']
    exim_collection = cols['impExp']

    st.title("🌾 AgriPredict Dashboard")
    if st.button("Get Live Data Feed"):
        fetch_and_store_data()

    view_mode = st.radio(
        "View Mode",
        ["Statistics", "Plots", "Predictions", "Exim"],
        horizontal=True,
        label_visibility="collapsed",
    )
    if view_mode == "Plots":
        _render_plots_view(collection)
    elif view_mode == "Predictions":
        _render_predictions_view()
    elif view_mode == "Statistics":
        _render_statistics_view(collection)
    elif view_mode == "Exim":
        _render_exim_view(exim_collection)


def _render_login():
    """Render the login form and authenticate the submitted credentials."""
    with st.form("login_form"):
        st.subheader("Please log in")
        username = st.text_input("Username")
        password = st.text_input("Password", type="password")
        login_button = st.form_submit_button("Login")

    if not login_button:
        return

    # The users collection is only fetched once the form is submitted.
    try:
        cols = get_collections()
        users_collection = cols['users_collection']
    except Exception as exc:
        st.error(f"Database connection error: {exc}")
        st.stop()  # halts this script run; nothing below executes

    if authenticate_user(username, password, users_collection):
        st.session_state.authenticated = True
        st.session_state['username'] = username
        st.write("Login successful!")
        # Rerun so the dashboard is rendered for the now-authenticated session.
        st.rerun()
    else:
        st.error("Invalid username or password")


# Script entry: the auth flag defaults to False for a session that has none.
if 'authenticated' not in st.session_state:
    st.session_state.authenticated = False

if st.session_state.authenticated:
    _render_dashboard()
else:
    _render_login()