vincentiusyoshuac commited on
Commit
0dc2bb2
·
verified ·
1 Parent(s): df5944e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -95
app.py CHANGED
@@ -2,146 +2,122 @@ import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
- from amazon_chronos import ChronosModel, TimeSeriesDataSet
6
- import tempfile
7
- import os
8
 
9
  class TimeSeriesForecaster:
10
- def __init__(self):
11
- self.model = None
12
- self.dataset = None
 
 
 
13
  self.original_series = None
 
14
 
15
  def preprocess_data(self, df, date_column, value_column, context_length=30, prediction_length=7):
16
  """
17
- Persiapkan data time series dari DataFrame
18
  """
19
- # Pastikan data terurut berdasarkan tanggal
20
  df = df.sort_values(by=date_column)
21
 
22
- # Konversi kolom tanggal ke datetime
23
  df[date_column] = pd.to_datetime(df[date_column])
24
 
25
- # Set index ke tanggal
26
  df.set_index(date_column, inplace=True)
27
 
28
- # Ekstrak series numerik
29
- time_series = df[value_column].values
30
 
31
- # Buat dataset Chronos
32
- self.original_series = time_series
33
- self.dataset = TimeSeriesDataSet.from_series(
34
- time_series,
35
- context_length=context_length,
36
- prediction_length=prediction_length
37
- )
38
 
39
- return self.dataset
40
-
41
- def train_model(self, model_id='chronos-t5-small'):
42
- """
43
- Latih model Chronos
44
- """
45
- self.model = ChronosModel.from_pretrained(model_id)
46
- self.model.fit(self.dataset)
47
- return self.model
48
 
49
- def forecast(self, n_samples=100):
50
  """
51
- Lakukan prediksi
52
  """
53
- if not self.model or not self.dataset:
54
- raise ValueError("Model belum dilatih. Latih model terlebih dahulu.")
55
-
56
- forecasts = self.model.predict(self.dataset, num_samples=n_samples)
57
  return forecasts
58
 
59
- def visualize_forecast(self, forecasts):
60
  """
61
- Buat visualisasi prediksi
62
  """
63
  plt.figure(figsize=(12, 6))
64
 
65
- # Plot series asli
66
- plt.plot(self.original_series, label='Data Historis', color='blue')
67
 
68
- # Plot prediksi
69
- forecast_mean = forecasts.mean(axis=0)
70
- forecast_lower = np.percentile(forecasts, 10, axis=0)
71
- forecast_upper = np.percentile(forecasts, 90, axis=0)
72
 
73
- forecast_start = len(self.original_series)
74
- plt.plot(
75
- range(forecast_start, forecast_start + len(forecast_mean)),
76
- forecast_mean,
77
- label='Prediksi Rata-rata',
78
- color='red'
79
- )
80
- plt.fill_between(
81
- range(forecast_start, forecast_start + len(forecast_mean)),
82
- forecast_lower,
83
- forecast_upper,
84
- alpha=0.3,
85
- color='red'
86
- )
87
 
88
- plt.title('Peramalan Time Series dengan Amazon Chronos')
89
- plt.xlabel('Indeks Waktu')
90
- plt.ylabel('Nilai')
91
  plt.legend()
92
 
93
  return plt
94
 
95
  def main():
96
- st.title('🕰️ Time Series Forecasting dengan Amazon Chronos')
97
 
98
- # Sidebar untuk upload dan konfigurasi
99
- st.sidebar.header('Pengaturan Prediksi')
100
 
101
- # Upload file CSV
102
  uploaded_file = st.sidebar.file_uploader(
103
- "Unggah File CSV",
104
  type=['csv'],
105
- help="Pastikan file CSV memiliki kolom tanggal dan nilai numerik"
106
  )
107
 
108
- # Pilihan kolom
109
  if uploaded_file is not None:
110
- # Baca CSV
111
  df = pd.read_csv(uploaded_file)
112
 
113
- # Pilih kolom
114
  date_column = st.sidebar.selectbox(
115
- 'Pilih Kolom Tanggal',
116
  options=df.columns
117
  )
118
  value_column = st.sidebar.selectbox(
119
- 'Pilih Kolom Nilai',
120
  options=[col for col in df.columns if col != date_column]
121
  )
122
 
123
- # Parameter prediksi
124
  context_length = st.sidebar.slider(
125
- 'Panjang Konteks',
126
  min_value=10,
127
  max_value=100,
128
  value=30
129
  )
130
  prediction_length = st.sidebar.slider(
131
- 'Panjang Prediksi',
132
  min_value=1,
133
  max_value=30,
134
  value=7
135
  )
136
 
137
- # Tombol proses
138
- if st.sidebar.button('Lakukan Prediksi'):
139
  try:
140
- # Inisiasi forecaster
141
  forecaster = TimeSeriesForecaster()
142
 
143
- # Preprocessing
144
- dataset = forecaster.preprocess_data(
145
  df,
146
  date_column,
147
  value_column,
@@ -149,33 +125,31 @@ def main():
149
  prediction_length
150
  )
151
 
152
- # Latih model
153
- model = forecaster.train_model()
154
-
155
- # Lakukan prediksi
156
- forecasts = forecaster.forecast()
157
 
158
- # Tampilkan hasil
159
- st.subheader('Visualisasi Prediksi')
160
- plt = forecaster.visualize_forecast(forecasts)
161
  st.pyplot(plt)
162
 
163
- # Tampilkan detail prediksi
164
- forecast_mean = forecasts.mean(axis=0)
165
- forecast_lower = np.percentile(forecasts, 10, axis=0)
166
- forecast_upper = np.percentile(forecasts, 90, axis=0)
 
167
 
168
  prediction_df = pd.DataFrame({
169
- 'Prediksi Rata-rata': forecast_mean,
170
- 'Batas Bawah (10%)': forecast_lower,
171
- 'Batas Atas (90%)': forecast_upper
172
  })
173
 
174
- st.subheader('Detail Prediksi')
175
  st.dataframe(prediction_df)
176
 
177
  except Exception as e:
178
- st.error(f"Terjadi kesalahan: {str(e)}")
179
 
180
  if __name__ == '__main__':
181
  main()
 
2
  import pandas as pd
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
+ import torch
6
+ from chronos import ChronosPipeline
 
7
 
8
  class TimeSeriesForecaster:
9
+ def __init__(self, model_name="amazon/chronos-t5-small"):
10
+ self.pipeline = ChronosPipeline.from_pretrained(
11
+ model_name,
12
+ device_map="cuda" if torch.cuda.is_available() else "cpu",
13
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
14
+ )
15
  self.original_series = None
16
+ self.context = None
17
 
18
  def preprocess_data(self, df, date_column, value_column, context_length=30, prediction_length=7):
19
  """
20
+ Prepare time series data from DataFrame
21
  """
22
+ # Ensure data is sorted by date
23
  df = df.sort_values(by=date_column)
24
 
25
+ # Convert date column to datetime
26
  df[date_column] = pd.to_datetime(df[date_column])
27
 
28
+ # Set index to date
29
  df.set_index(date_column, inplace=True)
30
 
31
+ # Extract numeric series
32
+ self.original_series = df[value_column].values
33
 
34
+ # Convert to tensor
35
+ self.context = torch.tensor(self.original_series[-context_length:], dtype=torch.float32)
 
 
 
 
 
36
 
37
+ return self.context, context_length
 
 
 
 
 
 
 
 
38
 
39
+ def forecast(self, context, prediction_length=7, num_samples=100):
40
  """
41
+ Perform time series forecasting
42
  """
43
+ forecasts = self.pipeline.predict(context, prediction_length, num_samples=num_samples)
 
 
 
44
  return forecasts
45
 
46
+ def visualize_forecast(self, context, forecasts):
47
  """
48
+ Create visualization of predictions
49
  """
50
  plt.figure(figsize=(12, 6))
51
 
52
+ # Plot original series
53
+ plt.plot(range(len(self.original_series)), self.original_series, label='Historical Data', color='blue')
54
 
55
+ # Calculate forecast statistics
56
+ forecast_np = forecasts[0].numpy()
57
+ low, median, high = np.quantile(forecast_np, [0.1, 0.5, 0.9], axis=0)
 
58
 
59
+ # Plot forecast
60
+ forecast_index = range(len(self.original_series), len(self.original_series) + len(median))
61
+ plt.plot(forecast_index, median, color='red', label='Median Forecast')
62
+ plt.fill_between(forecast_index, low, high, color='red', alpha=0.3, label='80% Prediction Interval')
 
 
 
 
 
 
 
 
 
 
63
 
64
+ plt.title('Time Series Forecasting with Amazon Chronos')
65
+ plt.xlabel('Time Index')
66
+ plt.ylabel('Value')
67
  plt.legend()
68
 
69
  return plt
70
 
71
  def main():
72
+ st.title('🕰️ Time Series Forecasting with Amazon Chronos')
73
 
74
+ # Sidebar for upload and configuration
75
+ st.sidebar.header('Forecast Settings')
76
 
77
+ # Upload CSV file
78
  uploaded_file = st.sidebar.file_uploader(
79
+ "Upload CSV File",
80
  type=['csv'],
81
+ help="Ensure CSV file has date and numeric columns"
82
  )
83
 
84
+ # Column selection and prediction settings
85
  if uploaded_file is not None:
86
+ # Read CSV
87
  df = pd.read_csv(uploaded_file)
88
 
89
+ # Select columns
90
  date_column = st.sidebar.selectbox(
91
+ 'Select Date Column',
92
  options=df.columns
93
  )
94
  value_column = st.sidebar.selectbox(
95
+ 'Select Value Column',
96
  options=[col for col in df.columns if col != date_column]
97
  )
98
 
99
+ # Prediction parameters
100
  context_length = st.sidebar.slider(
101
+ 'Context Length',
102
  min_value=10,
103
  max_value=100,
104
  value=30
105
  )
106
  prediction_length = st.sidebar.slider(
107
+ 'Prediction Length',
108
  min_value=1,
109
  max_value=30,
110
  value=7
111
  )
112
 
113
+ # Process button
114
+ if st.sidebar.button('Perform Forecast'):
115
  try:
116
+ # Initialize forecaster
117
  forecaster = TimeSeriesForecaster()
118
 
119
+ # Preprocess data
120
+ context, _ = forecaster.preprocess_data(
121
  df,
122
  date_column,
123
  value_column,
 
125
  prediction_length
126
  )
127
 
128
+ # Perform forecasting
129
+ forecasts = forecaster.forecast(context, prediction_length)
 
 
 
130
 
131
+ # Visualize results
132
+ st.subheader('Forecast Visualization')
133
+ plt = forecaster.visualize_forecast(context, forecasts)
134
  st.pyplot(plt)
135
 
136
+ # Display forecast details
137
+ forecast_np = forecasts[0].numpy()
138
+ forecast_mean = forecast_np.mean(axis=0)
139
+ forecast_lower = np.percentile(forecast_np, 10, axis=0)
140
+ forecast_upper = np.percentile(forecast_np, 90, axis=0)
141
 
142
  prediction_df = pd.DataFrame({
143
+ 'Mean Forecast': forecast_mean,
144
+ 'Lower Bound (10%)': forecast_lower,
145
+ 'Upper Bound (90%)': forecast_upper
146
  })
147
 
148
+ st.subheader('Forecast Details')
149
  st.dataframe(prediction_df)
150
 
151
  except Exception as e:
152
+ st.error(f"An error occurred: {str(e)}")
153
 
154
  if __name__ == '__main__':
155
  main()