| """Streamlit entrypoint for AgriPredict (refactored). | |
| Run with: `streamlit run streamlit_app.py` from project root. | |
| """ | |
| import streamlit as st | |
| from datetime import datetime, timedelta | |
| import pandas as pd | |
| from sklearn.preprocessing import MinMaxScaler | |
| import os | |
| from dotenv import load_dotenv | |
| from src.agri_predict import ( | |
| fetch_and_process_data, | |
| fetch_and_store_data, | |
| preprocess_data, | |
| train_and_forecast, | |
| forecast, | |
| collection_to_dataframe, | |
| get_dataframe_from_collection, | |
| ) | |
| from src.agri_predict.constants import state_market_dict | |
| from src.agri_predict.utils import authenticate_user | |
| from src.agri_predict.config import get_collections | |
| # Load environment variables | |
| load_dotenv() | |
| IS_PROD = os.getenv("PROD", "False").lower() == "true" | |
| st.set_page_config(layout="wide") | |
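
# Flow: a simple username/password check against MongoDB unlocks the dashboard;
# once authenticated, a single radio switches between the Statistics, Plots,
# Predictions and Exim views defined below.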
@st.cache_data
def load_all_data(_collection):
    """Load all data from the MongoDB collection.

    The leading underscore on ``_collection`` tells ``st.cache_data`` not to
    hash the (unhashable) Mongo collection handle.
    """
    data = list(_collection.find({}))
    if not data:
        return pd.DataFrame()
    df = pd.DataFrame(data)
    # Drop MongoDB _id field
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])
    # Convert data types
    if 'Reported Date' in df.columns:
        df['Reported Date'] = pd.to_datetime(df['Reported Date'])
    if 'Modal Price (Rs./Quintal)' in df.columns:
        df['Modal Price (Rs./Quintal)'] = pd.to_numeric(df['Modal Price (Rs./Quintal)'], errors='coerce')
    if 'Arrivals (Tonnes)' in df.columns:
        df['Arrivals (Tonnes)'] = pd.to_numeric(df['Arrivals (Tonnes)'], errors='coerce')
    return df
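
# get_filtered_data builds a Mongo query with a rolling date lower bound and
# optional "State Name" / "Market Name" filters, then normalises the dtypes
# the plots expect.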
def get_filtered_data(_collection, state=None, market=None, days=30):
    """Get filtered data based on parameters."""
    query_filter = {"Reported Date": {"$gte": datetime.now() - timedelta(days=days)}}
    if state and state != 'India':
        query_filter["State Name"] = state
    if market:
        query_filter["Market Name"] = market

    data = list(_collection.find(query_filter))
    if not data:
        return pd.DataFrame()
    df = pd.DataFrame(data)
    # Drop MongoDB _id field
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])
    # Convert data types
    if 'Reported Date' in df.columns:
        df['Reported Date'] = pd.to_datetime(df['Reported Date'])
    if 'Modal Price (Rs./Quintal)' in df.columns:
        df['Modal Price (Rs./Quintal)'] = pd.to_numeric(df['Modal Price (Rs./Quintal)'], errors='coerce')
    if 'Arrivals (Tonnes)' in df.columns:
        df['Arrivals (Tonnes)'] = pd.to_numeric(df['Arrivals (Tonnes)'], errors='coerce')
    return df
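
# Page chrome: constrain the content width and apply the green styling used
# for headings and buttons.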
st.markdown("""
<style>
.main {
    max-width: 1200px;
    margin: 0 auto;
    padding: 2rem;
}
.block-container {
    max-width: 1200px;
    padding-left: 5rem;
    padding-right: 5rem;
}
h1 { color: #4CAF50; font-family: 'Arial Black', sans-serif; }
.stButton>button { background-color: #4CAF50; color: white; }
</style>
""", unsafe_allow_html=True)
if 'authenticated' not in st.session_state:
    st.session_state.authenticated = False

if st.session_state.authenticated:
    # Get collections after authentication
    try:
        cols = get_collections()
    except Exception as exc:
        st.error(f"Configuration error: {exc}")
        st.stop()

    collection = cols['collection']
    impExp = cols['impExp']

    st.title("🌾 AgriPredict Dashboard")

    if st.button("Get Live Data Feed"):
        fetch_and_store_data()

    view_mode = st.radio("View Mode", ["Statistics", "Plots", "Predictions", "Exim"], horizontal=True, label_visibility="collapsed")
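
    # --- Plots view: price / arrivals trends with sidebar filters ---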
    if view_mode == "Plots":
        st.sidebar.header("Filters")
        selected_period = st.sidebar.selectbox("Select Time Period", ["2 Weeks", "1 Month", "3 Months", "1 Year", "2 Years", "5 Years"], index=1)
        period_mapping = {"2 Weeks": 14, "1 Month": 30, "3 Months": 90, "1 Year": 365, "2 Years": 730, "5 Years": 1825}
        st.session_state.selected_period = period_mapping[selected_period]

        state_options = list(state_market_dict.keys()) + ['India']
        selected_state = st.sidebar.selectbox("Select State", state_options)

        market_wise = False
        if selected_state != 'India':
            market_wise = st.sidebar.checkbox("Market Wise Analysis")
            if market_wise:
                markets = state_market_dict.get(selected_state, [])
                selected_market = st.sidebar.selectbox("Select Market", markets)
                query_filter = {"State Name": selected_state, "Market Name": selected_market}
            else:
                query_filter = {"State Name": selected_state}
        else:
            query_filter = {}

        data_type = st.sidebar.radio("Select Data Type", ["Price", "Volume", "Both"])

        if st.sidebar.button("✨ Let's go!"):
            try:
                # Load data for the selected filters
                state_param = selected_state if selected_state != 'India' else None
                market_param = selected_market if market_wise else None
                df = get_filtered_data(collection, state_param, market_param, st.session_state.selected_period)

                if not df.empty:
                    # Group by date and aggregate
                    df_grouped = df.groupby('Reported Date', as_index=False).agg({
                        'Arrivals (Tonnes)': 'sum',
                        'Modal Price (Rs./Quintal)': 'mean'
                    })

                    # Create a complete daily date range and reindex onto it
                    date_range = pd.date_range(
                        start=df_grouped['Reported Date'].min(),
                        end=df_grouped['Reported Date'].max(),
                        freq='D'
                    )
                    df_grouped = df_grouped.set_index('Reported Date').reindex(date_range).rename_axis('Reported Date').reset_index()

                    # Fill gaps introduced by the reindex: forward-fill, then back-fill the leading edge
                    df_grouped['Arrivals (Tonnes)'] = df_grouped['Arrivals (Tonnes)'].ffill().bfill()
                    df_grouped['Modal Price (Rs./Quintal)'] = df_grouped['Modal Price (Rs./Quintal)'].ffill().bfill()

                    st.subheader(f"📈 Trends for {selected_state} ({'Market: ' + selected_market if market_wise else 'State'})")

                    if data_type == "Both":
                        # Min-max scale price and arrivals so both fit on one axis
                        scaler = MinMaxScaler()
                        df_grouped[['Scaled Price', 'Scaled Arrivals']] = scaler.fit_transform(
                            df_grouped[['Modal Price (Rs./Quintal)', 'Arrivals (Tonnes)']]
                        )
                        fig = go.Figure()
                        fig.add_trace(go.Scatter(
                            x=df_grouped['Reported Date'],
                            y=df_grouped['Scaled Price'],
                            mode='lines',
                            name='Scaled Price',
                            line=dict(width=1, color='green'),
                            text=df_grouped['Modal Price (Rs./Quintal)'],
                            hovertemplate='Date: %{x}<br>Scaled Price: %{y:.2f}<br>Actual Price: %{text:.2f}<extra></extra>'
                        ))
                        fig.add_trace(go.Scatter(
                            x=df_grouped['Reported Date'],
                            y=df_grouped['Scaled Arrivals'],
                            mode='lines',
                            name='Scaled Arrivals',
                            line=dict(width=1, color='blue'),
                            text=df_grouped['Arrivals (Tonnes)'],
                            hovertemplate='Date: %{x}<br>Scaled Arrivals: %{y:.2f}<br>Actual Arrivals: %{text:.2f}<extra></extra>'
                        ))
                        fig.update_layout(
                            title="Price and Arrivals Trend",
                            xaxis_title='Date',
                            yaxis_title='Scaled Values',
                            template='plotly_white'
                        )
                        st.plotly_chart(fig, use_container_width=True)
                    elif data_type == "Price":
                        # Plot Modal Price
                        fig = go.Figure()
                        fig.add_trace(go.Scatter(
                            x=df_grouped['Reported Date'],
                            y=df_grouped['Modal Price (Rs./Quintal)'],
                            mode='lines',
                            name='Modal Price',
                            line=dict(width=1, color='green')
                        ))
                        fig.update_layout(title="Modal Price Trend", xaxis_title='Date', yaxis_title='Price (Rs./Quintal)', template='plotly_white')
                        st.plotly_chart(fig, use_container_width=True)
                    elif data_type == "Volume":
                        # Plot Arrivals (Tonnes)
                        fig = go.Figure()
                        fig.add_trace(go.Scatter(
                            x=df_grouped['Reported Date'],
                            y=df_grouped['Arrivals (Tonnes)'],
                            mode='lines',
                            name='Arrivals',
                            line=dict(width=1, color='blue')
                        ))
                        fig.update_layout(title="Arrivals Trend", xaxis_title='Date', yaxis_title='Volume (in Tonnes)', template='plotly_white')
                        st.plotly_chart(fig, use_container_width=True)
                else:
                    st.warning("⚠️ No data found for the selected filters.")
            except Exception as e:
                st.error(f"❌ Error fetching data: {e}")
    elif view_mode == "Predictions":
        st.subheader("📊 Model Analysis")
        sub_option = st.radio("Select one of the following", ["India", "States", "Market"], horizontal=True)
        sub_timeline = st.radio("Select one of the following horizons", ["14 days", "1 month", "3 months"], horizontal=True)

        if sub_option == "States":
            states = ["Karnataka", "Madhya Pradesh", "Gujarat", "Uttar Pradesh", "Telangana"]
            selected_state = st.selectbox("Select State for Model Training", states)
            filter_key = f"state_{selected_state}"

            if not IS_PROD and st.button("Train and Forecast"):
                query_filter = {"State Name": selected_state}
                df = fetch_and_process_data(query_filter)
                if df is not None:
                    if sub_timeline == "14 days":
                        train_and_forecast(df, filter_key, 14)
                    elif sub_timeline == "1 month":
                        train_and_forecast(df, filter_key, 30)
                    else:
                        train_and_forecast(df, filter_key, 90)
                else:
                    st.error("❌ No data available for the selected state.")

            if st.button("Forecast"):
                query_filter = {"State Name": selected_state}
                df = fetch_and_process_data(query_filter)
                if df is not None:
                    if sub_timeline == "14 days":
                        forecast(df, filter_key, 14)
                    elif sub_timeline == "1 month":
                        forecast(df, filter_key, 30)
                    else:
                        forecast(df, filter_key, 90)
                else:
                    st.error("❌ No data available for the selected state.")
        elif sub_option == "Market":
            market_options = ["Rajkot", "Gondal", "Kalburgi", "Amreli"]
            selected_market = st.selectbox("Select Market for Model Training", market_options)
            filter_key = f"market_{selected_market}"

            if not IS_PROD and st.button("Train and Forecast"):
                query_filter = {"Market Name": selected_market}
                df = fetch_and_process_data(query_filter)
                if df is not None:
                    if sub_timeline == "14 days":
                        train_and_forecast(df, filter_key, 14)
                    elif sub_timeline == "1 month":
                        train_and_forecast(df, filter_key, 30)
                    else:
                        train_and_forecast(df, filter_key, 90)
                else:
                    st.error("❌ No data available for the selected market.")

            # Plain `if` (not `elif`) so the Forecast button always renders,
            # matching the States and India branches.
            if st.button("Forecast"):
                query_filter = {"Market Name": selected_market}
                df = fetch_and_process_data(query_filter)
                if df is not None:
                    if sub_timeline == "14 days":
                        forecast(df, filter_key, 14)
                    elif sub_timeline == "1 month":
                        forecast(df, filter_key, 30)
                    else:
                        forecast(df, filter_key, 90)
                else:
                    st.error("❌ No data available for the selected market.")
        elif sub_option == "India":
            df = collection_to_dataframe(impExp)

            if not IS_PROD and st.button("Train and Forecast"):
                query_filter = {}
                df = fetch_and_process_data(query_filter)
                if df is not None:
                    if sub_timeline == "14 days":
                        train_and_forecast(df, "India", 14)
                    elif sub_timeline == "1 month":
                        train_and_forecast(df, "India", 30)
                    else:
                        train_and_forecast(df, "India", 90)
                else:
                    st.error("❌ No data available for forecasting.")

            if st.button("Forecast"):
                query_filter = {}
                df = fetch_and_process_data(query_filter)
                if df is not None:
                    if sub_timeline == "14 days":
                        forecast(df, "India", 14)
                    elif sub_timeline == "1 month":
                        forecast(df, "India", 30)
                    else:
                        forecast(df, "India", 90)
                else:
                    st.error("❌ No data available for forecasting.")
    elif view_mode == "Statistics":
        # Use cached data loading
        df = load_all_data(collection)
        if not df.empty:
            from src.agri_predict.plotting import display_statistics
            display_statistics(df)
        else:
            st.warning("No data available to display statistics.")
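
    # --- Exim view: import/export price and quantity trends ---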
    elif view_mode == "Exim":
        df = collection_to_dataframe(impExp)

        plot_option = st.radio("Select the data to visualize:", ["Import Price", "Import Quantity", "Export Price", "Export Quantity"], horizontal=True)
        time_period = st.selectbox("Select time period:", ["1 Month", "6 Months", "1 Year", "2 Years"])

        df["Reported Date"] = pd.to_datetime(df["Reported Date"], format="%Y-%m-%d")
        if time_period == "1 Month":
            start_date = pd.Timestamp.now() - pd.DateOffset(months=1)
        elif time_period == "6 Months":
            start_date = pd.Timestamp.now() - pd.DateOffset(months=6)
        elif time_period == "1 Year":
            start_date = pd.Timestamp.now() - pd.DateOffset(years=1)
        else:
            start_date = pd.Timestamp.now() - pd.DateOffset(years=2)
        filtered_df = df[df["Reported Date"] >= start_date]

        if plot_option == "Import Price":
            grouped_df = filtered_df.groupby("Reported Date", as_index=False)["VALUE_IMPORT"].mean().rename(columns={"VALUE_IMPORT": "Average Import Price"})
            y_axis_label = "Average Import Price (Rs.)"
        elif plot_option == "Import Quantity":
            grouped_df = filtered_df.groupby("Reported Date", as_index=False)["QUANTITY_IMPORT"].sum().rename(columns={"QUANTITY_IMPORT": "Total Import Quantity"})
            y_axis_label = "Total Import Quantity (Tonnes)"
        elif plot_option == "Export Price":
            grouped_df = filtered_df.groupby("Reported Date", as_index=False)["VALUE_EXPORT"].mean().rename(columns={"VALUE_EXPORT": "Average Export Price"})
            y_axis_label = "Average Export Price (Rs.)"
        else:
            # Export Quantity; assumes the collection exposes QUANTITY_EXPORT alongside VALUE_EXPORT
            grouped_df = filtered_df.groupby("Reported Date", as_index=False)["QUANTITY_EXPORT"].sum().rename(columns={"QUANTITY_EXPORT": "Total Export Quantity"})
            y_axis_label = "Total Export Quantity (Tonnes)"

        fig = px.line(grouped_df, x="Reported Date", y=grouped_df.columns[1], title=f"{plot_option} Over Time", labels={"Reported Date": "Date", grouped_df.columns[1]: y_axis_label})
        st.plotly_chart(fig)
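
# Unauthenticated sessions only see the login form.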
else:
    with st.form("login_form"):
        st.subheader("Please log in")
        username = st.text_input("Username")
        password = st.text_input("Password", type="password")
        login_button = st.form_submit_button("Login")

    if login_button:
        # Get collections for authentication
        try:
            cols = get_collections()
            users_collection = cols['users_collection']
        except Exception as exc:
            st.error(f"Database connection error: {exc}")
            st.stop()

        if authenticate_user(username, password, users_collection):
            st.session_state.authenticated = True
            st.session_state['username'] = username
            st.write("Login successful!")
            st.rerun()
        else:
            st.error("Invalid username or password")