Spaces:
Sleeping
Sleeping
vincentiusyoshuac
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -2,146 +2,122 @@ import streamlit as st
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
5 |
-
|
6 |
-
import
|
7 |
-
import os
|
8 |
|
9 |
class TimeSeriesForecaster:
|
10 |
-
def __init__(self):
|
11 |
-
self.
|
12 |
-
|
|
|
|
|
|
|
13 |
self.original_series = None
|
|
|
14 |
|
15 |
def preprocess_data(self, df, date_column, value_column, context_length=30, prediction_length=7):
|
16 |
"""
|
17 |
-
|
18 |
"""
|
19 |
-
#
|
20 |
df = df.sort_values(by=date_column)
|
21 |
|
22 |
-
#
|
23 |
df[date_column] = pd.to_datetime(df[date_column])
|
24 |
|
25 |
-
# Set index
|
26 |
df.set_index(date_column, inplace=True)
|
27 |
|
28 |
-
#
|
29 |
-
|
30 |
|
31 |
-
#
|
32 |
-
self.original_series =
|
33 |
-
self.dataset = TimeSeriesDataSet.from_series(
|
34 |
-
time_series,
|
35 |
-
context_length=context_length,
|
36 |
-
prediction_length=prediction_length
|
37 |
-
)
|
38 |
|
39 |
-
return self.
|
40 |
-
|
41 |
-
def train_model(self, model_id='chronos-t5-small'):
|
42 |
-
"""
|
43 |
-
Latih model Chronos
|
44 |
-
"""
|
45 |
-
self.model = ChronosModel.from_pretrained(model_id)
|
46 |
-
self.model.fit(self.dataset)
|
47 |
-
return self.model
|
48 |
|
49 |
-
def forecast(self,
|
50 |
"""
|
51 |
-
|
52 |
"""
|
53 |
-
|
54 |
-
raise ValueError("Model belum dilatih. Latih model terlebih dahulu.")
|
55 |
-
|
56 |
-
forecasts = self.model.predict(self.dataset, num_samples=n_samples)
|
57 |
return forecasts
|
58 |
|
59 |
-
def visualize_forecast(self, forecasts):
|
60 |
"""
|
61 |
-
|
62 |
"""
|
63 |
plt.figure(figsize=(12, 6))
|
64 |
|
65 |
-
# Plot series
|
66 |
-
plt.plot(self.original_series, label='Data
|
67 |
|
68 |
-
#
|
69 |
-
|
70 |
-
|
71 |
-
forecast_upper = np.percentile(forecasts, 90, axis=0)
|
72 |
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
label='Prediksi Rata-rata',
|
78 |
-
color='red'
|
79 |
-
)
|
80 |
-
plt.fill_between(
|
81 |
-
range(forecast_start, forecast_start + len(forecast_mean)),
|
82 |
-
forecast_lower,
|
83 |
-
forecast_upper,
|
84 |
-
alpha=0.3,
|
85 |
-
color='red'
|
86 |
-
)
|
87 |
|
88 |
-
plt.title('
|
89 |
-
plt.xlabel('
|
90 |
-
plt.ylabel('
|
91 |
plt.legend()
|
92 |
|
93 |
return plt
|
94 |
|
95 |
def main():
|
96 |
-
st.title('🕰️ Time Series Forecasting
|
97 |
|
98 |
-
# Sidebar
|
99 |
-
st.sidebar.header('
|
100 |
|
101 |
-
# Upload file
|
102 |
uploaded_file = st.sidebar.file_uploader(
|
103 |
-
"
|
104 |
type=['csv'],
|
105 |
-
help="
|
106 |
)
|
107 |
|
108 |
-
#
|
109 |
if uploaded_file is not None:
|
110 |
-
#
|
111 |
df = pd.read_csv(uploaded_file)
|
112 |
|
113 |
-
#
|
114 |
date_column = st.sidebar.selectbox(
|
115 |
-
'
|
116 |
options=df.columns
|
117 |
)
|
118 |
value_column = st.sidebar.selectbox(
|
119 |
-
'
|
120 |
options=[col for col in df.columns if col != date_column]
|
121 |
)
|
122 |
|
123 |
-
#
|
124 |
context_length = st.sidebar.slider(
|
125 |
-
'
|
126 |
min_value=10,
|
127 |
max_value=100,
|
128 |
value=30
|
129 |
)
|
130 |
prediction_length = st.sidebar.slider(
|
131 |
-
'
|
132 |
min_value=1,
|
133 |
max_value=30,
|
134 |
value=7
|
135 |
)
|
136 |
|
137 |
-
#
|
138 |
-
if st.sidebar.button('
|
139 |
try:
|
140 |
-
#
|
141 |
forecaster = TimeSeriesForecaster()
|
142 |
|
143 |
-
#
|
144 |
-
|
145 |
df,
|
146 |
date_column,
|
147 |
value_column,
|
@@ -149,33 +125,31 @@ def main():
|
|
149 |
prediction_length
|
150 |
)
|
151 |
|
152 |
-
#
|
153 |
-
|
154 |
-
|
155 |
-
# Lakukan prediksi
|
156 |
-
forecasts = forecaster.forecast()
|
157 |
|
158 |
-
#
|
159 |
-
st.subheader('
|
160 |
-
plt = forecaster.visualize_forecast(forecasts)
|
161 |
st.pyplot(plt)
|
162 |
|
163 |
-
#
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
167 |
|
168 |
prediction_df = pd.DataFrame({
|
169 |
-
'
|
170 |
-
'
|
171 |
-
'
|
172 |
})
|
173 |
|
174 |
-
st.subheader('
|
175 |
st.dataframe(prediction_df)
|
176 |
|
177 |
except Exception as e:
|
178 |
-
st.error(f"
|
179 |
|
180 |
if __name__ == '__main__':
|
181 |
main()
|
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
5 |
+
import torch
|
6 |
+
from chronos import ChronosPipeline
|
|
|
7 |
|
8 |
class TimeSeriesForecaster:
    """Forecast a single univariate series with a pretrained Chronos pipeline.

    Typical use: ``preprocess_data`` -> ``forecast`` -> ``visualize_forecast``.
    """

    def __init__(self, model_name="amazon/chronos-t5-small"):
        """Load the pretrained pipeline identified by *model_name*.

        CUDA with bfloat16 is selected when ``torch.cuda.is_available()``
        returns true; otherwise the pipeline is loaded on CPU with float32.
        """
        use_cuda = torch.cuda.is_available()
        self.pipeline = ChronosPipeline.from_pretrained(
            model_name,
            device_map="cuda" if use_cuda else "cpu",
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
        )
        # Full value series in chronological order; set by preprocess_data().
        self.original_series = None
        # Trailing window of the series as a float32 tensor; set by preprocess_data().
        self.context = None

    def preprocess_data(self, df, date_column, value_column, context_length=30, prediction_length=7):
        """Build the model context from a DataFrame.

        Args:
            df: DataFrame holding at least *date_column* and *value_column*.
                It is not modified; a copy is processed.
            date_column: Name of the column parseable by ``pd.to_datetime``.
            value_column: Name of the column holding the numeric observations.
            context_length: Number of most recent observations kept as context.
            prediction_length: Accepted for interface compatibility; not used
                by this method.

        Returns:
            A ``(context, context_length)`` tuple where ``context`` is a 1-D
            float32 tensor with at most *context_length* trailing values.

        Raises:
            ValueError: If *context_length* is not positive, the series is
                empty, or the value column cannot be converted to numbers.
        """
        if context_length < 1:
            raise ValueError("context_length must be a positive integer.")

        frame = df.copy()

        # Parse dates BEFORE sorting: sorting the raw strings would order them
        # lexicographically (e.g. "10/1/2019" before "9/1/2019").
        frame[date_column] = pd.to_datetime(frame[date_column])
        # mergesort is stable, so rows sharing a timestamp keep their file order.
        frame = frame.sort_values(by=date_column, kind="mergesort")
        frame = frame.set_index(date_column)

        # Coerce explicitly so a text column fails here with a clear message
        # instead of deep inside torch.tensor().
        values = pd.to_numeric(frame[value_column]).to_numpy(dtype=np.float64)
        if values.size == 0:
            raise ValueError("The selected value column contains no data.")

        self.original_series = values
        self.context = torch.tensor(values[-context_length:], dtype=torch.float32)

        return self.context, context_length

    def forecast(self, context, prediction_length=7, num_samples=100):
        """Sample forecasts for *context* from the loaded pipeline.

        Args:
            context: Tensor of past observations, as returned by
                ``preprocess_data``.
            prediction_length: Number of future steps to predict.
            num_samples: Number of sample paths requested from the pipeline.

        Returns:
            Whatever ``self.pipeline.predict`` returns; the rest of this class
            indexes it with ``[0]`` to obtain the samples for the single input
            series.
        """
        forecasts = self.pipeline.predict(context, prediction_length, num_samples=num_samples)
        return forecasts

    def visualize_forecast(self, context, forecasts):
        """Plot the historical series together with the forecast band.

        Args:
            context: Unused; kept so existing callers keep working.
            forecasts: Output of ``forecast``; ``forecasts[0]`` is treated as
                a 2-D array whose first axis indexes samples (assumed
                ``(num_samples, prediction_length)`` -- confirm against the
                Chronos documentation).

        Returns:
            The ``matplotlib.pyplot`` module, with the new figure as the
            current figure.
        """
        plt.figure(figsize=(12, 6))

        history = self.original_series
        history_len = len(history)
        plt.plot(range(history_len), history, label='Historical Data', color='blue')

        # detach/cpu/float make the conversion safe for tensors that live on
        # an accelerator or use bfloat16, which NumPy cannot represent.
        samples = forecasts[0].detach().cpu().float().numpy()
        low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)

        # The forecast continues on the x-axis right after the last observation.
        forecast_index = range(history_len, history_len + len(median))
        plt.plot(forecast_index, median, color='red', label='Median Forecast')
        plt.fill_between(forecast_index, low, high, color='red', alpha=0.3, label='80% Prediction Interval')

        plt.title('Time Series Forecasting with Amazon Chronos')
        plt.xlabel('Time Index')
        plt.ylabel('Value')
        plt.legend()

        return plt
|
70 |
|
71 |
def main():
    """Render the Streamlit app: upload a CSV, configure, forecast, display.

    The sidebar collects the CSV file, the date/value columns and the
    context/prediction lengths. When the "Perform Forecast" button is pressed
    a ``TimeSeriesForecaster`` is created, the forecast is plotted, and a
    table with mean / 10th / 90th percentile values per step is shown. Any
    exception raised during that run is reported with ``st.error``.
    """
    st.title('🕰️ Time Series Forecasting with Amazon Chronos')

    # Sidebar for upload and configuration
    st.sidebar.header('Forecast Settings')

    uploaded_file = st.sidebar.file_uploader(
        "Upload CSV File",
        type=['csv'],
        help="Ensure CSV file has date and numeric columns"
    )

    # Nothing else can be configured until a file is available.
    if uploaded_file is None:
        return

    df = pd.read_csv(uploaded_file)

    date_column = st.sidebar.selectbox(
        'Select Date Column',
        options=df.columns
    )
    value_column = st.sidebar.selectbox(
        'Select Value Column',
        options=[col for col in df.columns if col != date_column]
    )

    context_length = st.sidebar.slider(
        'Context Length',
        min_value=10,
        max_value=100,
        value=30
    )
    prediction_length = st.sidebar.slider(
        'Prediction Length',
        min_value=1,
        max_value=30,
        value=7
    )

    if not st.sidebar.button('Perform Forecast'):
        return

    # A single-column CSV leaves the value selectbox without options.
    if value_column is None:
        st.error("The CSV file needs a value column in addition to the date column.")
        return

    try:
        forecaster = TimeSeriesForecaster()

        # Keyword arguments keep the two lengths from being swapped or
        # shifted into the wrong positional slot.
        context, _ = forecaster.preprocess_data(
            df,
            date_column,
            value_column,
            context_length=context_length,
            prediction_length=prediction_length,
        )

        forecasts = forecaster.forecast(context, prediction_length)

        st.subheader('Forecast Visualization')
        # Named `chart` so the module-level `plt` import is not shadowed.
        chart = forecaster.visualize_forecast(context, forecasts)
        st.pyplot(chart)

        # detach/cpu/float make the conversion safe for accelerator-resident
        # or bfloat16 tensors, which NumPy cannot represent directly.
        forecast_np = forecasts[0].detach().cpu().float().numpy()
        prediction_df = pd.DataFrame({
            'Mean Forecast': forecast_np.mean(axis=0),
            'Lower Bound (10%)': np.percentile(forecast_np, 10, axis=0),
            'Upper Bound (90%)': np.percentile(forecast_np, 90, axis=0),
        })

        st.subheader('Forecast Details')
        st.dataframe(prediction_df)

    except Exception as e:
        # Top-level UI boundary: surface the failure to the user.
        st.error(f"An error occurred: {str(e)}")
|
153 |
|
154 |
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|