|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import joblib |
|
import gradio as gr |
|
from dateutil.relativedelta import relativedelta |
|
import calendar |
|
|
|
def load_model():
    """Load the persisted forecasting model from 'arima_sales_model.pkl'.

    The file is resolved relative to the current working directory.

    Returns:
        A ``(model, error)`` pair: ``(model, None)`` when loading succeeds,
        or ``(None, message)`` describing the exception when it does not.
    """
    # NOTE: joblib.load unpickles arbitrary objects -- only ever point this
    # at a trusted, locally produced model file.
    try:
        loaded = joblib.load('arima_sales_model.pkl')
    except Exception as exc:
        return None, f"Failed to load model: {str(exc)}"
    return loaded, None
|
|
|
def parse_date(date_str):
    """Parse the custom date format 'Month-Year' into that month's bounds.

    Args:
        date_str: Text such as 'January-2024' (full English month name, a
            hyphen, four-digit year). Surrounding whitespace is ignored.

    Returns:
        ``(start_date, end_date, None)`` where ``start_date`` is midnight on the
        first day of the month and ``end_date`` is midnight on the last day, or
        ``(None, None, message)`` when the input is missing or malformed.
    """
    error_msg = "Date format should be 'Month-Year', e.g., 'January-2024'."

    if isinstance(date_str, str):
        date_str = date_str.strip()
    # Gradio textboxes send "" when left empty; None would otherwise crash on
    # `.year` because pd.to_datetime(None) returns None rather than raising.
    if date_str is None or date_str == "":
        return None, None, error_msg

    try:
        date = pd.to_datetime(date_str, format="%B-%Y")
    except (ValueError, TypeError):
        return None, None, error_msg
    # Inputs such as "NaT" parse to a missing value instead of raising.
    if pd.isna(date):
        return None, None, error_msg

    _, last_day = calendar.monthrange(date.year, date.month)
    start_date = date.replace(day=1)
    end_date = date.replace(day=last_day)
    return start_date, end_date, None
|
|
|
def _read_sales_csv(uploaded_file):
    """Load the uploaded CSV and rename its 'Date'/'Sale' columns to ``ds``/``y``.

    Returns ``(df, None)`` on success, with ``ds`` parsed to datetimes, or
    ``(None, message)`` when the file cannot be read or is missing columns.
    """
    import os

    # Depending on the Gradio version, gr.File hands over a filepath string
    # (or str subclass) or a tempfile wrapper whose ``.name`` is the path --
    # TODO confirm for the deployed Gradio. Objects without ``.name``
    # (e.g. in-memory buffers) are passed to pandas unchanged.
    if isinstance(uploaded_file, (str, os.PathLike)):
        source = uploaded_file
    else:
        source = getattr(uploaded_file, "name", uploaded_file)

    try:
        df = pd.read_csv(source)
    except Exception as e:
        return None, f"Failed to read the uploaded CSV file: {str(e)}"

    if 'Date' not in df.columns or 'Sale' not in df.columns:
        return None, "The uploaded file must contain 'Date' and 'Sale' columns."

    # Parse here, under error handling, so malformed dates become a UI message
    # instead of an unhandled exception.
    try:
        df['Date'] = pd.to_datetime(df['Date'])
    except (ValueError, TypeError) as e:
        return None, f"Failed to parse the 'Date' column: {str(e)}"

    return df.rename(columns={'Date': 'ds', 'Sale': 'y'}), None


def _build_forecast_figure(history, forecast_df):
    """Draw actual sales (``history`` ds/y) and the forecast on one Figure."""
    # A bare Figure is not registered with pyplot, so a figure built on every
    # request is freed with its last reference. Figures from plt.subplots()
    # would accumulate in pyplot's global registry because they are never closed.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    ax.plot(history['ds'], history['y'], label='Actual Sales', color='blue')
    ax.plot(forecast_df['Date'], forecast_df['Sales Forecast'],
            label='Sales Forecast', color='red', linestyle='--')
    ax.set_xlabel('Date')
    ax.set_ylabel('Sales')
    ax.set_title('Sales Forecasting with ARIMA')
    ax.legend()
    return fig


def forecast_sales(uploaded_file, start_date_str, end_date_str, horizon_days=60):
    """Plot actual sales for a month range followed by an ARIMA forecast.

    This is the handler for a Gradio button whose outputs are
    ``[output_plot, output_message]``, so every path returns exactly two values.

    Args:
        uploaded_file: CSV upload containing 'Date' and 'Sale' columns.
        start_date_str: First month of the history window, e.g. 'January-2024'.
        end_date_str: Last month of the history window, e.g. 'December-2024'.
        horizon_days: Number of forecast steps requested from the model and
            plotted as consecutive days after the end month (default 60).

    Returns:
        ``(figure, message)`` on success, ``(None, error_message)`` otherwise.
    """
    if uploaded_file is None:
        return None, "No file uploaded. Please upload a file."

    df, error = _read_sales_csv(uploaded_file)
    if df is None:
        return None, error

    start_date, _, error = parse_date(start_date_str)
    _, end_date, error_end = parse_date(end_date_str)
    if error or error_end:
        return None, error or error_end
    if start_date > end_date:
        return None, "Start date must not be after end date."

    # end_date is midnight on the month's last day. Compare against the next
    # midnight so rows timestamped later on that last day are still included.
    end_exclusive = end_date + pd.Timedelta(days=1)
    in_range = (df['ds'] >= start_date) & (df['ds'] < end_exclusive)
    # Sort so the line plot follows time order even if the CSV is unordered.
    df_filtered = df[in_range].sort_values('ds')

    arima_model, error = load_model()
    if arima_model is None:
        return None, error

    try:
        forecast = arima_model.get_forecast(steps=horizon_days)
        # NOTE(review): assumes the model's steps are daily and start right
        # after end_date -- confirm against how the model was trained.
        forecast_index = pd.date_range(start=end_date, periods=horizon_days + 1, freq='D')[1:]
        forecast_df = pd.DataFrame({
            'Date': forecast_index,
            # Pair values positionally with forecast_index, independent of
            # whatever index the model attaches to its predictions.
            'Sales Forecast': pd.Series(forecast.predicted_mean).to_numpy(),
        })
        fig = _build_forecast_figure(df_filtered, forecast_df)
    except Exception as e:
        return None, f"Failed to generate plot: {str(e)}"

    return fig, "File loaded and processed successfully."
|
|
|
def setup_interface():
    """Build the Gradio Blocks UI that feeds the upload and date inputs to forecast_sales.

    Returns:
        The ``gr.Blocks`` app. Call ``.launch()`` on it to serve the UI.
    """
    with gr.Blocks() as app:
        gr.Markdown("## MLCast v1.1 - Intelligent Sales Forecasting System")
        with gr.Row():
            # Narrow left column: data upload, date range, and the trigger.
            with gr.Column(scale=1):
                upload = gr.File(label="Upload your store data")
                start_box = gr.Textbox(label="Start Date", placeholder="January-2024")
                end_box = gr.Textbox(label="End Date", placeholder="December-2024")
                run_btn = gr.Button("Forecast Sales")
            # Wider right column: chart plus the status/notification text.
            with gr.Column(scale=2):
                plot_out = gr.Plot()
                notifications = gr.Textbox(label="Notifications", visible=True, lines=2)
        # Event listeners must be registered inside the Blocks context.
        run_btn.click(
            fn=forecast_sales,
            inputs=[upload, start_box, end_box],
            outputs=[plot_out, notifications],
        )
    return app
|
|
|
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    setup_interface().launch()
|
|