File size: 3,590 Bytes
10caf6a 8b82ee3 7407a4e 8b82ee3 7407a4e 8b82ee3 b887aff 8b82ee3 b887aff 8b82ee3 7407a4e 8b82ee3 0ead373 8b82ee3 376f47b 8b82ee3 dab4e1c 8b82ee3 558f9a2 8b82ee3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import pandas as pd
import matplotlib.pyplot as plt
import joblib
import gradio as gr
from dateutil.relativedelta import relativedelta
import calendar
def load_model(path='arima_sales_model.pkl'):
    """Load the saved ARIMA model with joblib.

    Args:
        path: File to load. Defaults to ``'arima_sales_model.pkl'``, resolved
            relative to the current working directory.

    Returns:
        ``(model, None)`` on success, or ``(None, error_message)`` when loading
        fails for any reason (missing file, unpickling error, ...).
    """
    # NOTE(review): joblib.load unpickles arbitrary objects; only ever point
    # this at a trusted, locally deployed model file, never at uploaded data.
    try:
        model = joblib.load(path)
    except Exception as e:  # UI boundary: surface any failure as a message
        return None, f"Failed to load model: {str(e)}"
    return model, None
def parse_date(date_str):
    """Parse the custom date format 'Month-Year' into a whole-month span.

    Args:
        date_str: Text such as ``'January-2024'`` (full month name, dash,
            four-digit year). Surrounding whitespace is ignored.

    Returns:
        ``(start_date, end_date, error)``. On success ``start_date`` is the
        first day of the month, ``end_date`` is the last day of the month
        (both as midnight timestamps) and ``error`` is ``None``. On failure
        both dates are ``None`` and ``error`` is a user-facing message.
    """
    error_msg = "Date format should be 'Month-Year', e.g., 'January-2024'."
    # pd.to_datetime(None) returns None instead of raising, which would crash
    # on `.year` below, so reject non-string input up front.
    if not isinstance(date_str, str):
        return None, None, error_msg
    try:
        date = pd.to_datetime(date_str.strip(), format="%B-%Y")
    except ValueError:
        return None, None, error_msg
    # An empty string parses to NaT rather than raising.
    if pd.isna(date):
        return None, None, error_msg
    _, last_day = calendar.monthrange(date.year, date.month)
    return date.replace(day=1), date.replace(day=last_day), None
def forecast_sales(uploaded_file, start_date_str, end_date_str):
    """Plot actual sales for a month range followed by a 60-day ARIMA forecast.

    This is the click handler for a Gradio button wired to two outputs
    (plot, notification text), so every path returns exactly two values.

    Args:
        uploaded_file: Anything ``pd.read_csv`` accepts (path or file-like);
            the CSV must contain ``Date`` and ``Sale`` columns.
        start_date_str: First month to show, e.g. ``'January-2024'``.
        end_date_str: Last month to show, e.g. ``'December-2024'``. The
            forecast is drawn starting the day after this month ends.

    Returns:
        ``(figure, message)`` on success, or ``(None, error_message)``.
    """
    if uploaded_file is None:
        return None, "No file uploaded. Please upload a file."

    df, error = _read_sales_csv(uploaded_file)
    if error:
        return None, error

    start_date, _, error = parse_date(start_date_str)
    _, end_date, error_end = parse_date(end_date_str)
    if error or error_end:
        return None, error or error_end
    if start_date > end_date:
        return None, "Start Date must not be later than End Date."

    # Unparseable dates would otherwise escape as an unhandled exception.
    try:
        df['Date'] = pd.to_datetime(df['Date'])
    except (ValueError, TypeError) as e:
        return None, f"Could not parse the 'Date' column: {str(e)}"
    df = df.rename(columns={'Date': 'ds', 'Sale': 'y'})
    df_filtered = df[(df['ds'] >= start_date) & (df['ds'] <= end_date)]

    arima_model, error = load_model()
    if arima_model is None:
        return None, error

    try:
        fig = _plot_forecast(arima_model, df_filtered, end_date)
    except Exception as e:  # UI boundary: report instead of crashing
        return None, f"Failed to generate plot: {str(e)}"
    return fig, "File loaded and processed successfully."


def _read_sales_csv(uploaded_file):
    """Read the uploaded CSV; return ``(df, None)`` or ``(None, error_message)``."""
    try:
        df = pd.read_csv(uploaded_file)
    except Exception as e:
        return None, f"Failed to read the uploaded CSV file: {str(e)}"
    if 'Date' not in df.columns or 'Sale' not in df.columns:
        return None, "The uploaded file must contain 'Date' and 'Sale' columns."
    return df, None


def _plot_forecast(arima_model, actual, end_date, horizon=60):
    """Build a figure of actual sales plus a daily forecast starting after end_date."""
    # A standalone Figure (rather than plt.subplots) is not kept in pyplot's
    # global figure registry, so figures from repeated clicks are not leaked.
    from matplotlib.figure import Figure

    forecast = arima_model.get_forecast(steps=horizon)
    # NOTE(review): if this is a statsmodels results object, get_forecast
    # continues from the end of the model's training sample, not from
    # end_date; this x-axis placement assumes the two coincide. Confirm.
    forecast_index = pd.date_range(start=end_date, periods=horizon + 1, freq='D')[1:]
    forecast_df = pd.DataFrame({
        'Date': forecast_index,
        # Strip any index the model attaches so values line up positionally
        # with forecast_index.
        'Sales Forecast': pd.Series(forecast.predicted_mean).to_numpy(),
    })

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    ax.plot(actual['ds'], actual['y'], label='Actual Sales', color='blue')
    ax.plot(forecast_df['Date'], forecast_df['Sales Forecast'],
            label='Sales Forecast', color='red', linestyle='--')
    ax.set_xlabel('Date')
    ax.set_ylabel('Sales')
    ax.set_title('Sales Forecasting with ARIMA')
    ax.legend()
    return fig
def setup_interface():
    """Assemble the MLCast Gradio UI and return the (not yet launched) app.

    Left column: file upload, start/end month textboxes and the trigger
    button. Right column: the plot and a notification textbox. Clicking the
    button calls ``forecast_sales`` with the three inputs and sends its
    results to the plot and the notification box.
    """
    with gr.Blocks() as app:
        gr.Markdown("## MLCast v1.1 - Intelligent Sales Forecasting System")
        with gr.Row():
            # Component creation order below determines the on-screen layout.
            with gr.Column(scale=1):
                uploaded = gr.File(label="Upload your store data")
                start_box = gr.Textbox(label="Start Date", placeholder="January-2024")
                end_box = gr.Textbox(label="End Date", placeholder="December-2024")
                run_button = gr.Button("Forecast Sales")
            with gr.Column(scale=2):
                plot_view = gr.Plot()
                notice_box = gr.Textbox(label="Notifications", visible=True, lines=2)
        handler_inputs = [uploaded, start_box, end_box]
        handler_outputs = [plot_view, notice_box]
        run_button.click(
            fn=forecast_sales,
            inputs=handler_inputs,
            outputs=handler_outputs,
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the Gradio app and start serving it.
    setup_interface().launch()
|