File size: 3,590 Bytes
10caf6a
8b82ee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7407a4e
8b82ee3
7407a4e
8b82ee3
 
 
 
 
 
b887aff
8b82ee3
 
 
 
b887aff
8b82ee3
 
7407a4e
8b82ee3
0ead373
8b82ee3
 
 
376f47b
8b82ee3
 
 
 
dab4e1c
8b82ee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558f9a2
8b82ee3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import pandas as pd
import matplotlib.pyplot as plt
import joblib
import gradio as gr
from dateutil.relativedelta import relativedelta
import calendar

def load_model(path='arima_sales_model.pkl'):
    """Load the serialized ARIMA model with joblib.

    Args:
        path: Location of the joblib/pickle file. Defaults to
            ``'arima_sales_model.pkl'``, resolved relative to the current
            working directory.

    Returns:
        ``(model, None)`` on success, or ``(None, message)`` if loading fails
        for any reason (missing file, corrupt pickle, import error, ...).
    """
    # SECURITY: joblib.load unpickles arbitrary objects and can execute code;
    # only ever point this at a trusted, locally produced model file.
    try:
        model = joblib.load(path)
    except Exception as e:  # best-effort: surface the failure to the UI instead of crashing
        return None, f"Failed to load model: {str(e)}"
    return model, None

def parse_date(date_str):
    """Parse a 'Month-Year' string (e.g. 'January-2024') into a month span.

    Surrounding whitespace is ignored.

    Args:
        date_str: Text from the UI textbox; may be empty or ``None``.

    Returns:
        ``(first_day, last_day, None)`` as pandas Timestamps at midnight on the
        first and last calendar day of the month, or ``(None, None, message)``
        when the input is missing or not in 'Month-Year' format.
    """
    error_message = "Date format should be 'Month-Year', e.g., 'January-2024'."
    # pd.to_datetime maps "" to NaT and None to None instead of raising, which
    # would later blow up on `.year`; reject such inputs up front.
    if not isinstance(date_str, str) or not date_str.strip():
        return None, None, error_message
    try:
        date = pd.to_datetime(date_str.strip(), format="%B-%Y")
    except (ValueError, TypeError):
        return None, None, error_message
    if pd.isna(date):  # e.g. the literal string "NaT"
        return None, None, error_message
    _, last_day = calendar.monthrange(date.year, date.month)
    return date.replace(day=1), date.replace(day=last_day), None

def forecast_sales(uploaded_file, start_date_str, end_date_str, forecast_steps=60):
    """Plot actual sales in a month range followed by an ARIMA forecast.

    This is the handler behind the "Forecast Sales" button, which is wired to
    two outputs (a ``gr.Plot`` and a notifications ``gr.Textbox``), so every
    path returns exactly two values.

    Args:
        uploaded_file: The value from ``gr.File``: a path string, an object
            exposing the path as ``.name`` (older Gradio tempfile wrapper), or
            any buffer accepted by ``pd.read_csv``. ``None`` if nothing was
            uploaded. The CSV must contain 'Date' and 'Sale' columns.
        start_date_str: First month to plot, in 'Month-Year' format.
        end_date_str: Last month to plot, in 'Month-Year' format.
        forecast_steps: Number of steps requested from the model's
            ``get_forecast`` and plotted as daily points after the end date.

    Returns:
        ``(figure, message)``. ``figure`` is a matplotlib Figure on success
        and ``None`` on any error, in which case ``message`` explains why.
    """
    if uploaded_file is None:
        return None, "No file uploaded. Please upload a file."

    # Older Gradio versions hand over a tempfile wrapper whose `.name` is the
    # path; newer ones pass the path itself. pandas accepts paths or buffers.
    source = getattr(uploaded_file, "name", uploaded_file)
    try:
        df = pd.read_csv(source)
    except Exception as e:  # pandas raises many exception types for bad input
        return None, f"Failed to read the uploaded CSV file: {str(e)}"
    if 'Date' not in df.columns or 'Sale' not in df.columns:
        return None, "The uploaded file must contain 'Date' and 'Sale' columns."

    start_date, _, start_error = parse_date(start_date_str)
    _, end_date, end_error = parse_date(end_date_str)
    if start_error or end_error:
        return None, start_error or end_error
    if start_date > end_date:
        return None, "Start date must not be after end date."

    try:
        df['Date'] = pd.to_datetime(df['Date'])
    except (ValueError, TypeError) as e:
        return None, f"Could not parse the 'Date' column: {str(e)}"
    df = df.rename(columns={'Date': 'ds', 'Sale': 'y'})

    # end_date is midnight of the last day, so same-day rows with a later
    # time-of-day are excluded (fine for date-only data).
    df_filtered = df[(df['ds'] >= start_date) & (df['ds'] <= end_date)]

    arima_model, load_error = load_model()
    if arima_model is None:
        return None, load_error

    try:
        forecast = arima_model.get_forecast(steps=forecast_steps)
        # NOTE(review): the fitted model forecasts from the end of its own
        # training data at its own frequency; here the values are labelled as
        # daily points starting the day after end_date. Confirm these align.
        forecast_index = pd.date_range(start=end_date, periods=forecast_steps + 1, freq='D')[1:]
        # Drop any index carried by predicted_mean so values pair positionally
        # with forecast_index instead of being aligned on the model's index.
        predicted = pd.Series(forecast.predicted_mean).to_numpy()
        forecast_df = pd.DataFrame({'Date': forecast_index, 'Sales Forecast': predicted})

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(df_filtered['ds'], df_filtered['y'], label='Actual Sales', color='blue')
        ax.plot(forecast_df['Date'], forecast_df['Sales Forecast'], label='Sales Forecast', color='red', linestyle='--')
        ax.set_xlabel('Date')
        ax.set_ylabel('Sales')
        ax.set_title('Sales Forecasting with ARIMA')
        ax.legend()
        # Detach from pyplot's global figure registry so figures don't pile up
        # across requests; the Figure object itself stays renderable.
        plt.close(fig)
    except Exception as e:  # report any model/plotting failure to the UI
        return None, f"Failed to generate plot: {str(e)}"
    return fig, "File loaded and processed successfully."

def setup_interface():
    """Assemble the MLCast Gradio layout and hook up the forecast button.

    The left column collects the CSV upload and the month range; the right
    column shows the resulting plot and a notifications box.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` application.
    """
    with gr.Blocks() as app:
        gr.Markdown("## MLCast v1.1 - Intelligent Sales Forecasting System")
        with gr.Row():
            # Inputs.
            with gr.Column(scale=1):
                uploaded = gr.File(label="Upload your store data")
                start_box = gr.Textbox(label="Start Date", placeholder="January-2024")
                end_box = gr.Textbox(label="End Date", placeholder="December-2024")
                run_button = gr.Button("Forecast Sales")
            # Outputs.
            with gr.Column(scale=2):
                plot_area = gr.Plot()
                notifications = gr.Textbox(label="Notifications", visible=True, lines=2)
        run_button.click(
            fn=forecast_sales,
            inputs=[uploaded, start_box, end_box],
            outputs=[plot_area, notifications],
        )
    return app

if __name__ == "__main__":
    # Build the UI and start the local Gradio server.
    setup_interface().launch()