nafisneehal committed on
Commit
73ac9f6
·
verified ·
1 Parent(s): 2d1e9e0

Upload 3 files

Browse files
Files changed (3) hide show
  1. Trainer.py +235 -0
  2. fetch_plot_data.py +81 -0
  3. gradio_app.py +155 -0
Trainer.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hopsworks
2
+ import pandas as pd
3
+ import os
4
+ from datetime import datetime, timedelta
5
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
6
+ from sklearn.model_selection import train_test_split
7
+ import joblib
8
+ from pathlib import Path
9
+ import hsfs
10
+ import hsml
11
+
12
# Project root, taken to be the directory three levels above this file.
_THIS_FILE = Path(__file__).resolve()
BASE_DIR = _THIS_FILE.parent.parent.parent
14
+
15
+
16
class Trainer:
    """Bundle the Hopsworks calls used to train, register, deploy and query a model.

    The constructor logs in to Hopsworks right away, so creating an instance
    performs network I/O. Most methods report progress and API errors with
    ``print`` rather than raising.
    """

    # Model inputs are the columns "<field>_lag_<i>" for every lag index and
    # price field below, ordered lag-major (all four fields of lag 0, then
    # lag 1, ...).
    _LAGS = range(0, 13)
    _PRICE_FIELDS = ("open", "high", "low", "close")

    def __init__(self, project_name, feature_group_name, model_registry_name, api_key):
        """Log in to Hopsworks and fetch the feature-store and registry handles.

        Args:
            project_name: Hopsworks project name (used to build dataset paths).
            feature_group_name: Name of the feature group to read features from.
            model_registry_name: Name under which the model is registered; the
                deployment name is derived from it.
            api_key: Hopsworks API key passed to ``hopsworks.login``.
        """
        self.project_name = project_name
        self.feature_group_name = feature_group_name
        self.model_registry_name = model_registry_name
        self.api_key = api_key
        self.project = hopsworks.login(api_key_value=self.api_key)
        self.fs = self.project.get_feature_store()
        self.model_registry = self.project.get_model_registry()
        # Populated by create_feature_view() / model_deploy() respectively.
        self.feature_view = None
        self.deployment = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @classmethod
    def _feature_columns(cls):
        """Return the ordered list of lagged price column names fed to the model."""
        return [
            f"{prefix}_lag_{i}" for i in cls._LAGS for prefix in cls._PRICE_FIELDS
        ]

    def _deployment_name(self):
        """Return the deployment name: the registry name with underscores removed."""
        return self.model_registry_name.replace("_", "")

    def _require_feature_view(self):
        """Return the feature view, or raise a clear error if none is available."""
        if self.feature_view is None:
            raise RuntimeError(
                "Feature view is not available; call create_feature_view() "
                "first and check that it succeeded."
            )
        return self.feature_view

    def _recent_batch(self, window):
        """Fetch batch data for the ``window`` (a timedelta) ending now, newest first."""
        view = self._require_feature_view()
        # Read the clock once so both ends of the window are consistent.
        # NOTE(review): this is a naive local timestamp - confirm it matches
        # the time zone of the feature group's event time.
        end_time = datetime.now()
        start_time = end_time - window
        df = view.get_batch_data(start_time=start_time, end_time=end_time)
        return df.sort_values(by='datetime', ascending=False)

    # ------------------------------------------------------------------
    # Feature view management
    # ------------------------------------------------------------------
    def create_feature_view(self):
        """Select all features of the feature group and get or create a feature view.

        The view is named ``"<feature_group_name>_view"`` (version 1) and is
        stored on ``self.feature_view``. A ``RestAPIError`` while creating the
        view is printed, not raised; ``self.feature_view`` is then left
        unchanged.
        """
        selected_features = self.fs.get_or_create_feature_group(
            name=self.feature_group_name,
            version=1
        ).select_all()

        print("Feature group selected successfully......... --->>")

        # The query itself is unfiltered; time windows are applied when
        # batch data is read from the view.
        try:
            self.feature_view = self.fs.get_or_create_feature_view(
                name=f"{self.feature_group_name}_view",
                version=1,
                description="Feature view with last 30 days of data for model training",
                query=selected_features,
            )
            print("Feature view created or retrieved successfully.")
        except hsfs.client.exceptions.RestAPIError as e:
            print(f"Error creating feature view: {e}")

    def delete_feature_view(self):
        """Delete the current feature view, if there is one.

        API errors are printed, not raised.
        """
        if self.feature_view is None:
            print("No feature view to delete.")
            return
        try:
            self.feature_view.delete()
            print("Feature view deleted successfully.")
        except hsfs.client.exceptions.RestAPIError as e:
            print(f"Error deleting feature view: {e}")

    # ------------------------------------------------------------------
    # Data access
    # ------------------------------------------------------------------
    def get_retrain_data_from_feature_view(self):
        """Return the last 30 days of feature-view data, sorted newest first.

        Raises:
            RuntimeError: If no feature view has been created.
        """
        df = self._recent_batch(timedelta(days=30))
        print("Data pulled from feature view for retraining successfully.")
        return df

    def get_plot_data_from_feature_view(self, hours):
        """Return the last ``hours`` hours of feature-view data, sorted newest first.

        Args:
            hours: Length of the look-back window, in hours.

        Raises:
            RuntimeError: If no feature view has been created.
        """
        df = self._recent_batch(timedelta(hours=hours))
        print("Data pulled from feature view for plotting successfully.")
        return df

    def train_test_split(self, df, test_size=0.2):
        """Split ``df`` into train and test features and labels.

        Args:
            df: DataFrame containing the lagged feature columns and ``target``.
            test_size: Fraction of rows held out for testing.

        Returns:
            ``(X_train, X_test, y_train, y_test)``.
        """
        X, y = self.get_features_labels(df)

        # Fixed seed keeps the split reproducible.
        # NOTE(review): sklearn's train_test_split shuffles rows by default;
        # confirm that is intended for time-ordered data.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=42)
        print("Data split into train and test sets.")
        return X_train, X_test, y_train, y_test

    def get_features_labels(self, df):
        """Return ``(X, y)``: the lagged feature columns and the ``target`` column of ``df``."""
        X = df[self._feature_columns()]
        y = df['target']
        return X, y

    # ------------------------------------------------------------------
    # Training and evaluation
    # ------------------------------------------------------------------
    def train_model(self, model, X_train, y_train):
        """Fit ``model`` on the training data and return it."""
        model.fit(X_train, y_train)
        print("Model training completed.")
        return model

    def evaluate_model(self, model, X_test, y_test, **kwargs):
        """Evaluate ``model`` on the hold-out set and return its metrics.

        Args:
            model: Fitted estimator exposing ``predict``.
            X_test: Hold-out features.
            y_test: Hold-out labels.
            **kwargs: ``show_pred`` - when truthy, the predictions are printed.

        Returns:
            Dict with keys ``"mse"``, ``"mae"`` and ``"r2"``.
        """
        y_pred = model.predict(X_test)
        # Check the flag's value, not just its presence, so show_pred=False
        # really suppresses the output.
        if kwargs.get("show_pred"):
            print(f"Predictions: {y_pred}")
        mse = mean_squared_error(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)

        print(f"Model Evaluation:\nMSE: {mse}\nMAE: {mae}\nR2 Score: {r2}")
        return {"mse": mse, "mae": mae, "r2": r2}

    # ------------------------------------------------------------------
    # Registry, deployment and serving
    # ------------------------------------------------------------------
    def save_model_to_registry(self, model, metrics, model_schema, X_train):
        """Pickle ``model`` locally and register it in the Hopsworks model registry.

        Args:
            model: Trained estimator to persist.
            metrics: Metrics dict attached to the registry entry.
            model_schema: Schema object attached to the registry entry.
            X_train: Training features; one sampled row is used as the
                registry's input example.
        """
        model_dir = BASE_DIR / "models"
        # exist_ok makes a separate existence check unnecessary.
        model_dir.mkdir(parents=True, exist_ok=True)

        model_path = model_dir / f"{self.model_registry_name}.pkl"
        joblib.dump(model, model_path)

        new_model = self.model_registry.sklearn.create_model(
            name=self.model_registry_name,
            metrics=metrics,
            model_schema=model_schema,
            input_example=X_train.sample(),
            description="Trained model with 30-day feature view data",
        )

        # Upload the pickled artifact; deployment is done by model_deploy().
        new_model.save(str(model_path))
        print("Model saved to registry successfully.")

    def model_deploy(self):
        """Upload the predictor script, deploy the registered model and start it.

        The resulting deployment is stored on ``self.deployment``.
        """
        model = self.model_registry.get_model(
            self.model_registry_name)

        deploy_name = self._deployment_name()

        dataset_api = self.project.get_dataset_api()

        # Upload the predictor script to the "Models" dataset, overwriting any
        # existing file of the same name.
        predictor_local_path = BASE_DIR / "src" / \
            "training_pipeline" / "kserve_predict_script.py"
        uploaded_file_path = dataset_api.upload(
            predictor_local_path, "Models", overwrite=True)

        # Full project path of the uploaded predictor script.
        predictor_script_path = os.path.join(
            "/Projects", self.project_name, uploaded_file_path)

        self.deployment = model.deploy(
            name=deploy_name,
            script_file=predictor_script_path,)

        self.deployment.start()

    def predict_with_hopsworks_api(self, X):
        """Send the rows of ``X`` to the deployed model and return its response.

        Args:
            X: DataFrame of input features; sent as ``X.values.tolist()``.

        Returns:
            The deployment's prediction response, or ``None`` if the request
            failed (the error is printed).
        """
        model_serving = self.project.get_model_serving()
        deploy_name = self._deployment_name()

        try:
            deployment = model_serving.get_deployment(name=deploy_name)
            predictions = deployment.predict(inputs=X.values.tolist())
            print("Predictions made via Hopsworks model API.")
            return predictions
        except hsml.client.exceptions.RestAPIError as e:
            print(f"Error making predictions: {e}")
            return None
        except Exception as e:
            print(f"Unexpected error: {e}")
            return None

    def stop_model_deployment(self):
        """Force-delete this model's deployment, if one exists.

        Returns:
            The registry model (version 1) whose deployments were inspected.
        """
        model = self.model_registry.get_model(
            self.model_registry_name, version=1)
        deploy_name = self._deployment_name()

        model_serving = self.project.get_model_serving()

        try:
            deployments = model_serving.get_deployments(model)
            for deployment in deployments:
                if deployment.name == deploy_name:
                    # Forced delete also takes care of a running deployment.
                    deployment.delete(force=True)
                    print(
                        f"Deployment {deploy_name} stopped and deleted successfully.")
                    break
            else:
                # Loop finished without a match.
                print(f"No deployment found with name: {deploy_name}")
        except hsml.client.exceptions.RestAPIError as e:
            print(f"Error stopping or deleting deployment: {e}")

        return model
fetch_plot_data.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from dotenv import load_dotenv
3
+ import yaml
4
+ from pathlib import Path
5
+ from Trainer import Trainer # Assuming Trainer.py is in the same directory
6
+ import requests
7
+ import os
8
+ import json
9
+ import warnings
10
+ import pandas as pd
11
+ import hsml
12
warnings.filterwarnings('ignore')

# Load environment variables (such as HOPSWORKS_API_KEY) from a .env file.
load_dotenv()

# Hopsworks API key, read once from the environment and reused below.
HOPSWORKS_API_KEY = os.getenv("HOPSWORKS_API_KEY")

# Project root, taken to be the directory three levels above this file.
# NOTE(review): this assumes the file lives three levels below the project
# root - confirm that holds in the deployed layout.
BASE_DIR = Path(__file__).resolve().parent.parent.parent

# Project configuration is read from <BASE_DIR>/src/config.yml.
CONFIG_FILE = BASE_DIR / "src" / "config.yml"
with open(CONFIG_FILE, 'r', encoding="utf-8") as file:
    configs = yaml.safe_load(file)


# The part of the symbol before "/" (lower-cased) names the feature group and
# the registered model.
symbol = configs['stock_api_params']['symbol']
_asset_key = symbol.split('/')[0].lower()

# Shared Trainer instance used by the functions in this module; constructing
# it logs in to Hopsworks.
trainer = Trainer(
    project_name=configs['hopsworks']['project_name'],
    feature_group_name=f"{_asset_key}_features",
    model_registry_name=f"{_asset_key}_regressor_model",
    api_key=HOPSWORKS_API_KEY,
)
38
+
39
+
40
def return_plot_data(hours):
    """Fetch the last ``hours`` hours of data and split it for plotting.

    Returns:
        ``(features, labels, datetimes)`` - the model input columns, the
        ``target`` column and the ``datetime`` column of the fetched frame.
    """
    # Make sure the module-level trainer has a feature view to read from.
    trainer.create_feature_view()

    frame = trainer.get_plot_data_from_feature_view(hours)
    timestamps = frame['datetime']
    features, labels = trainer.get_features_labels(frame)
    return features, labels, timestamps
53
+
54
+
55
def return_plot_data_prediction(input_features):
    """Return the deployed model's response for ``input_features``.

    This is whatever ``trainer.predict_with_hopsworks_api`` returns, which is
    ``None`` when the prediction request failed.
    """
    return trainer.predict_with_hopsworks_api(input_features)
59
+
60
+
61
def get_plot_data(hours):
    """Collect features, labels, predictions and timestamps for the last ``hours`` hours.

    Args:
        hours: Length of the look-back window, in hours.

    Returns:
        Dict with keys ``"features"``, ``"labels"``, ``"prediction"`` (the
        ``'predictions'`` entry of the serving response) and ``"datetime"``.

    Raises:
        RuntimeError: If the prediction request returned nothing.
    """
    input_features, input_labels, datetime_column = return_plot_data(hours)
    prediction = return_plot_data_prediction(input_features)

    # The prediction helper returns None on failure; fail with a clear message
    # instead of a TypeError from subscripting None.
    if prediction is None:
        raise RuntimeError(
            "Prediction request failed; no predictions were returned by the deployment.")

    return {"features": input_features, "labels": input_labels,
            "prediction": prediction['predictions'], "datetime": datetime_column}
68
+
69
+
70
+ # f, l, d = return_plot_data()
71
+ # print(f)
72
+ # print(l)
73
+ # print(trainer.predict_with_hopsworks_api(f))
74
+
75
+
76
+ # # Example input data (replace with your actual input structure)
77
+ # input_ls = [76480.91, 76648.94, 76390.51, 76541.99, 76330.78, 76339.94, 76312.67, 76319.28, 76246.58, 76413.26, 76206.41, 76333.14, 76396.64, 76732.32, 76151.9, 76244.62, 76279.09, 76429.21, 76222.1, 76396.63, 76122.3, 76283.43, 75758.58,
78
+ # 76272.1, 76349.99, 76366.2, 76093.0, 76117.98, 76395.53, 76456.16, 76319.87, 76348.18, 76461.01, 76481.48, 76300.38, 76395.53, 76330.91, 76517.26, 76323.53, 76461.02, 76532.39, 76583.19, 76319.32, 76330.91, 76509.82, 76570.6, 76415.72, 76534.61]
79
+ # input_columns = ["open_lag_1", "high_lag_1", "low_lag_1", "close_lag_1", "open_lag_2", "high_lag_2", "low_lag_2", "close_lag_2", "open_lag_3", "high_lag_3", "low_lag_3", "close_lag_3", "open_lag_4", "high_lag_4", "low_lag_4", "close_lag_4", "open_lag_5", "high_lag_5", "low_lag_5", "close_lag_5", "open_lag_6", "high_lag_6", "low_lag_6", "close_lag_6",
80
+ # "open_lag_7", "high_lag_7", "low_lag_7", "close_lag_7", "open_lag_8", "high_lag_8", "low_lag_8", "close_lag_8", "open_lag_9", "high_lag_9", "low_lag_9", "close_lag_9", "open_lag_10", "high_lag_10", "low_lag_10", "close_lag_10", "open_lag_11", "high_lag_11", "low_lag_11", "close_lag_11", "open_lag_12", "high_lag_12", "low_lag_12", "close_lag_12"]
81
+ # input_df = pd.DataFrame([input_ls], columns=input_columns)
gradio_app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from fetch_plot_data import get_plot_data
5
+
6
+
7
def get_time_series_data():
    """Build the long-format plotting frame and y-axis bounds for the last 24 hours.

    Returns:
        ``(long_data, y_min, y_max)`` where ``long_data`` has the columns
        ``"Datetime"``, ``"Series"`` and ``"BTC/USD Value"``, and the bounds
        are the min/max over actual and predicted values, widened by 0.05% of
        their range.
    """
    fetched = get_plot_data(hours=24)
    fetched["datetime"] = pd.to_datetime(fetched["datetime"])

    # One row per timestamp, in chronological order.
    wide = pd.DataFrame({
        "Datetime": fetched["datetime"],
        "Actual BTC/USD": fetched["labels"],
        "Predicted BTC/USD": fetched["prediction"]
    }).sort_values(by="Datetime")
    # Timestamps are formatted as strings only after sorting.
    wide["Datetime"] = wide["Datetime"].dt.strftime("%Y-%m-%d %H:%M")

    # Axis limits cover both series plus a small margin.
    combined = np.concatenate([wide["Actual BTC/USD"],
                               wide["Predicted BTC/USD"]])
    lowest = np.min(combined)
    highest = np.max(combined)
    margin = (highest - lowest) * 0.0005

    # Melt to one row per (timestamp, series) pair.
    tidy = wide.melt(
        id_vars="Datetime",
        var_name="Series",
        value_name="BTC/USD Value"
    )
    return (tidy, lowest - margin, highest + margin)
35
+
36
+
37
# Stylesheet for the page shell, headings, chart wrapper and footer.
custom_css = """
body {
    background-color: #f8fafc !important;
}

.gradio-container {
    max-width: 1200px !important;
    margin: 2rem auto !important;
    padding: 2rem !important;
    background-color: white !important;
    border-radius: 1rem !important;
    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
}

.main-title {
    color: #1e293b !important;
    font-size: 2.5rem !important;
    font-weight: 700 !important;
    text-align: center !important;
    margin-bottom: 0.5rem !important;
    line-height: 1.2 !important;
}

.subtitle {
    color: #64748b !important;
    font-size: 1.125rem !important;
    text-align: center !important;
    margin-bottom: 1.5rem !important;
    font-weight: 500 !important;
}

.chart-container {
    margin-bottom: 1rem !important;
}

.footer-content {
    margin-top: 1rem !important;
    padding-top: 1rem !important;
    border-top: 1px solid #e2e8f0 !important;
    display: flex !important;
    justify-content: space-between !important;
    align-items: center !important;
    color: #64748b !important;
    font-size: 0.875rem !important;
}

.footer-left {
    text-align: left !important;
}

.footer-right {
    text-align: right !important;
}

.developer-info {
    color: #3b82f6 !important;
    font-weight: 500 !important;
    text-decoration: none !important;
    transition: color 0.2s !important;
}

.developer-info:hover {
    color: #2563eb !important;
}
"""

# Build the page. The data and the "Last updated" timestamp are computed once,
# while the layout is constructed.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as app:
    with gr.Column():
        # Page heading.
        gr.Markdown("""
            <div class="main-title">Live BTC/USD Time Series Info</div>
            <div class="subtitle">Predictions served via Hopsworks API</div>
        """)

        # Fetch the series and y-axis bounds shown in the chart.
        initial_data, initial_y_min, initial_y_max = get_time_series_data()

        # Actual vs. predicted series, one line per "Series" value.
        with gr.Column(elem_classes=["chart-container"]):
            line_plot = gr.LinePlot(
                value=initial_data,
                # Data mapping
                x="Datetime",
                y="BTC/USD Value",
                color="Series",
                # Titles and axis labels
                title="",
                y_title="BTC/USD Value",
                x_title="Time",
                x_label_angle=45,
                # Size
                width=1000,
                height=450,
                # Series styling
                colors={
                    "Actual BTC/USD": "#3b82f6",
                    "Predicted BTC/USD": "#ef4444"
                },
                tooltip=["Datetime", "BTC/USD Value", "Series"],
                overlay_point=True,
                # Interaction
                zoom=False,
                pan=False,
                show_label=True,
                stroke_width=2,
                # Y-axis range
                y_min=initial_y_min,
                y_max=initial_y_max,
                y_lim=[initial_y_min, initial_y_max],
                show_grid=True,
            )

        # Footer: build timestamp and developer link.
        gr.Markdown(f"""
            <div class="footer-content">
                <div class="footer-left">
                    Last updated: {pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")}
                    <br>
                    <a href="https://nafis-neehal.github.io/" target="_blank" class="developer-info">Developed by Nafis Neehal</a>
                </div>
            </div>
        """)

# Start the Gradio server.
app.launch()