Spaces:
Runtime error
Runtime error
nafisneehal
commited on
Upload 3 files
Browse files- Trainer.py +235 -0
- fetch_plot_data.py +81 -0
- gradio_app.py +155 -0
Trainer.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hopsworks
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
from datetime import datetime, timedelta
|
5 |
+
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
6 |
+
from sklearn.model_selection import train_test_split
|
7 |
+
import joblib
|
8 |
+
from pathlib import Path
|
9 |
+
import hsfs
|
10 |
+
import hsml
|
11 |
+
|
12 |
+
# Define the base directory as the project root
|
13 |
+
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
14 |
+
|
15 |
+
|
16 |
+
class Trainer:
|
17 |
+
def __init__(self, project_name, feature_group_name, model_registry_name, api_key):
|
18 |
+
self.project_name = project_name
|
19 |
+
self.feature_group_name = feature_group_name
|
20 |
+
self.model_registry_name = model_registry_name
|
21 |
+
self.api_key = api_key
|
22 |
+
self.project = hopsworks.login(api_key_value=self.api_key)
|
23 |
+
self.fs = self.project.get_feature_store()
|
24 |
+
self.model_registry = self.project.get_model_registry()
|
25 |
+
self.feature_view = None
|
26 |
+
self.deployment = None
|
27 |
+
|
28 |
+
def create_feature_view(self):
|
29 |
+
"""Select features from the feature group and create a feature view."""
|
30 |
+
selected_features = self.fs.get_or_create_feature_group(
|
31 |
+
name=self.feature_group_name,
|
32 |
+
version=1
|
33 |
+
).select_all()
|
34 |
+
|
35 |
+
print("Feature group selected successfully......... --->>")
|
36 |
+
|
37 |
+
"""Create or get a feature view for the last 30 days of data."""
|
38 |
+
try:
|
39 |
+
self.feature_view = self.fs.get_or_create_feature_view(
|
40 |
+
name=f"{self.feature_group_name}_view",
|
41 |
+
version=1,
|
42 |
+
description="Feature view with last 30 days of data for model training",
|
43 |
+
query=selected_features,
|
44 |
+
)
|
45 |
+
print("Feature view created or retrieved successfully.")
|
46 |
+
except hsfs.client.exceptions.RestAPIError as e:
|
47 |
+
print(f"Error creating feature view: {e}")
|
48 |
+
|
49 |
+
def delete_feature_view(self):
|
50 |
+
"""Delete the feature view."""
|
51 |
+
try:
|
52 |
+
self.feature_view.delete()
|
53 |
+
print("Feature view deleted successfully.")
|
54 |
+
except hsfs.client.exceptions.RestAPIError as e:
|
55 |
+
print(f"Error deleting feature view: {e}")
|
56 |
+
|
57 |
+
def get_retrain_data_from_feature_view(self):
|
58 |
+
"""Pull the last 30 days of data from the feature view till today."""
|
59 |
+
start_time = datetime.now() - timedelta(days=30)
|
60 |
+
end_time = datetime.now()
|
61 |
+
|
62 |
+
# Get the data as a DataFrame from the feature view
|
63 |
+
df = self.feature_view.get_batch_data(
|
64 |
+
start_time=start_time, end_time=end_time)
|
65 |
+
|
66 |
+
# sort by datetime
|
67 |
+
df = df.sort_values(by='datetime', ascending=False)
|
68 |
+
print("Data pulled from feature view for retraining successfully.")
|
69 |
+
return df
|
70 |
+
|
71 |
+
def get_plot_data_from_feature_view(self, hours):
|
72 |
+
# get last 12 hours of data starting from current hour to plot
|
73 |
+
start_time = datetime.now() - timedelta(hours=hours)
|
74 |
+
end_time = datetime.now()
|
75 |
+
|
76 |
+
# Get the data as a DataFrame from the feature view
|
77 |
+
df = self.feature_view.get_batch_data(
|
78 |
+
start_time=start_time, end_time=end_time)
|
79 |
+
|
80 |
+
# sort by datetime
|
81 |
+
df = df.sort_values(by='datetime', ascending=False)
|
82 |
+
print("Data pulled from feature view for plotting successfully.")
|
83 |
+
return df
|
84 |
+
|
85 |
+
def train_test_split(self, df, test_size=0.2):
|
86 |
+
"""Split data into training and test sets."""
|
87 |
+
# Define feature columns based on lagged features
|
88 |
+
feature_columns = [
|
89 |
+
f"{prefix}_lag_{i}" for i in range(0, 13) for prefix in ["open", "high", "low", "close"]
|
90 |
+
]
|
91 |
+
|
92 |
+
# Separate features and target
|
93 |
+
X = df[feature_columns]
|
94 |
+
y = df['target']
|
95 |
+
|
96 |
+
# Split into train and test sets
|
97 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
98 |
+
X, y, test_size=test_size, random_state=42)
|
99 |
+
print("Data split into train and test sets.")
|
100 |
+
return X_train, X_test, y_train, y_test
|
101 |
+
|
102 |
+
def get_features_labels(self, df):
|
103 |
+
"""Split data into features and labels."""
|
104 |
+
# Define feature columns based on lagged features
|
105 |
+
feature_columns = [
|
106 |
+
f"{prefix}_lag_{i}" for i in range(0, 13) for prefix in ["open", "high", "low", "close"]
|
107 |
+
]
|
108 |
+
|
109 |
+
# Separate features and target
|
110 |
+
X = df[feature_columns]
|
111 |
+
y = df['target']
|
112 |
+
return X, y
|
113 |
+
|
114 |
+
def train_model(self, model, X_train, y_train):
|
115 |
+
"""Train the model on training data."""
|
116 |
+
model.fit(X_train, y_train)
|
117 |
+
print("Model training completed.")
|
118 |
+
return model
|
119 |
+
|
120 |
+
def evaluate_model(self, model, X_test, y_test, **kwargs):
|
121 |
+
"""Evaluate the model on the hold-out test set."""
|
122 |
+
y_pred = model.predict(X_test)
|
123 |
+
# if show_pred in kwargs is true, print the predictions
|
124 |
+
if "show_pred" in kwargs:
|
125 |
+
print(f"Predictions: {y_pred}")
|
126 |
+
mse = mean_squared_error(y_test, y_pred)
|
127 |
+
mae = mean_absolute_error(y_test, y_pred)
|
128 |
+
r2 = r2_score(y_test, y_pred)
|
129 |
+
|
130 |
+
print(f"Model Evaluation:\nMSE: {mse}\nMAE: {mae}\nR2 Score: {r2}")
|
131 |
+
return {"mse": mse, "mae": mae, "r2": r2}
|
132 |
+
|
133 |
+
def save_model_to_registry(self, model, metrics, model_schema, X_train):
|
134 |
+
"""Save the trained model to Hopsworks Model Registry."""
|
135 |
+
# Use BASE_DIR to define the model directory and path
|
136 |
+
model_dir = BASE_DIR / "models"
|
137 |
+
# Ensure the directory exists
|
138 |
+
if not model_dir.exists():
|
139 |
+
model_dir.mkdir(parents=True, exist_ok=True)
|
140 |
+
|
141 |
+
model_path = model_dir / f"{self.model_registry_name}.pkl"
|
142 |
+
joblib.dump(model, model_path)
|
143 |
+
|
144 |
+
new_model = self.model_registry.sklearn.create_model(
|
145 |
+
name=self.model_registry_name,
|
146 |
+
metrics=metrics,
|
147 |
+
model_schema=model_schema,
|
148 |
+
input_example=X_train.sample(),
|
149 |
+
description="Trained model with 30-day feature view data",
|
150 |
+
)
|
151 |
+
|
152 |
+
# Register the model and serve as endpoint
|
153 |
+
new_model.save(str(model_path))
|
154 |
+
# new_model.deploy()
|
155 |
+
print("Model saved to registry successfully.")
|
156 |
+
|
157 |
+
def model_deploy(self):
|
158 |
+
model = self.model_registry.get_model(
|
159 |
+
self.model_registry_name)
|
160 |
+
|
161 |
+
# strip all _ from self.model_registry_name and keep only alphanumeric characters
|
162 |
+
deploy_name = self.model_registry_name.replace("_", "")
|
163 |
+
|
164 |
+
# Get the dataset API for the project
|
165 |
+
dataset_api = self.project.get_dataset_api()
|
166 |
+
|
167 |
+
# Upload the file "predict_example.py" to the "Models" dataset
|
168 |
+
# If a file with the same name already exists, overwrite it
|
169 |
+
predictor_local_path = BASE_DIR / "src" / \
|
170 |
+
"training_pipeline" / "kserve_predict_script.py"
|
171 |
+
uploaded_file_path = dataset_api.upload(
|
172 |
+
predictor_local_path, "Models", overwrite=True)
|
173 |
+
|
174 |
+
# Construct the full path to the uploaded predictor script
|
175 |
+
predictor_script_path = os.path.join(
|
176 |
+
"/Projects", self.project_name, uploaded_file_path)
|
177 |
+
|
178 |
+
self.deployment = model.deploy(
|
179 |
+
name=deploy_name,
|
180 |
+
script_file=predictor_script_path,)
|
181 |
+
|
182 |
+
# start the deployment
|
183 |
+
self.deployment.start()
|
184 |
+
|
185 |
+
def predict_with_hopsworks_api(self, X):
|
186 |
+
"""Use the deployed model to make predictions via the Hopsworks API."""
|
187 |
+
# Get model serving handle from the project
|
188 |
+
model_serving = self.project.get_model_serving()
|
189 |
+
|
190 |
+
model = self.model_registry.get_model(
|
191 |
+
self.model_registry_name, version=1)
|
192 |
+
|
193 |
+
# Ensure the deployment name follows the required regex pattern
|
194 |
+
deploy_name = self.model_registry_name.replace("_", "")
|
195 |
+
|
196 |
+
try:
|
197 |
+
# Get the deployment
|
198 |
+
deployment = model_serving.get_deployment(name=deploy_name)
|
199 |
+
|
200 |
+
# Make predictions
|
201 |
+
predictions = deployment.predict(inputs=X.values.tolist())
|
202 |
+
print("Predictions made via Hopsworks model API.")
|
203 |
+
return predictions
|
204 |
+
except hsml.client.exceptions.RestAPIError as e:
|
205 |
+
print(f"Error making predictions: {e}")
|
206 |
+
return None
|
207 |
+
except Exception as e:
|
208 |
+
print(f"Unexpected error: {e}")
|
209 |
+
return None
|
210 |
+
|
211 |
+
def stop_model_deployment(self):
|
212 |
+
model = self.model_registry.get_model(
|
213 |
+
self.model_registry_name, version=1)
|
214 |
+
# Ensure the deployment name follows the required regex pattern
|
215 |
+
deploy_name = self.model_registry_name.replace("_", "")
|
216 |
+
|
217 |
+
# Get model serving handle
|
218 |
+
model_serving = self.project.get_model_serving()
|
219 |
+
|
220 |
+
try:
|
221 |
+
# List deployments
|
222 |
+
deployments = model_serving.get_deployments(model)
|
223 |
+
for deployment in deployments:
|
224 |
+
if deployment.name == deploy_name:
|
225 |
+
# deployment.stop()
|
226 |
+
deployment.delete(force=True)
|
227 |
+
print(
|
228 |
+
f"Deployment {deploy_name} stopped and deleted successfully.")
|
229 |
+
break
|
230 |
+
else:
|
231 |
+
print(f"No deployment found with name: {deploy_name}")
|
232 |
+
except hsml.client.exceptions.RestAPIError as e:
|
233 |
+
print(f"Error stopping or deleting deployment: {e}")
|
234 |
+
|
235 |
+
return model
|
fetch_plot_data.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pprint
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
import yaml
|
4 |
+
from pathlib import Path
|
5 |
+
from Trainer import Trainer # Assuming Trainer.py is in the same directory
|
6 |
+
import requests
|
7 |
+
import os
|
8 |
+
import json
|
9 |
+
import warnings
|
10 |
+
import pandas as pd
|
11 |
+
import hsml
|
12 |
+
warnings.filterwarnings('ignore')
|
13 |
+
|
14 |
+
load_dotenv()
|
15 |
+
|
16 |
+
# Hopsworks API configuration
|
17 |
+
# Or replace with your actual API key
|
18 |
+
HOPSWORKS_API_KEY = os.getenv("HOPSWORKS_API_KEY")
|
19 |
+
|
20 |
+
# Define the base directory as the project root
|
21 |
+
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
22 |
+
|
23 |
+
# Use BASE_DIR to dynamically load the config file
|
24 |
+
CONFIG_FILE = BASE_DIR / "src" / "config.yml"
|
25 |
+
with open(CONFIG_FILE, 'r') as file:
|
26 |
+
configs = yaml.safe_load(file)
|
27 |
+
|
28 |
+
|
29 |
+
# Initialize Trainer instance with Hopsworks project configurations
|
30 |
+
symbol = configs['stock_api_params']['symbol']
|
31 |
+
# Initialize Trainer with relevant project details
|
32 |
+
trainer = Trainer(
|
33 |
+
project_name=configs['hopsworks']['project_name'],
|
34 |
+
feature_group_name=f"{symbol.split('/')[0].lower()}_features",
|
35 |
+
model_registry_name=f"{symbol.split('/')[0].lower()}_regressor_model",
|
36 |
+
api_key=os.getenv("HOPSWORKS_API_KEY")
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
def return_plot_data(hours):
|
41 |
+
# Create or retrieve feature view
|
42 |
+
trainer.create_feature_view()
|
43 |
+
|
44 |
+
# Get the plot data from the feature view
|
45 |
+
input_df = trainer.get_plot_data_from_feature_view(hours)
|
46 |
+
|
47 |
+
# get the datetime column from the input_df
|
48 |
+
datetime_column = input_df['datetime']
|
49 |
+
|
50 |
+
input_features, input_labels = trainer.get_features_labels(input_df)
|
51 |
+
|
52 |
+
return input_features, input_labels, datetime_column
|
53 |
+
|
54 |
+
|
55 |
+
def return_plot_data_prediction(input_features):
|
56 |
+
# Get the prediction
|
57 |
+
prediction = trainer.predict_with_hopsworks_api(input_features)
|
58 |
+
return prediction
|
59 |
+
|
60 |
+
|
61 |
+
def get_plot_data(hours):
|
62 |
+
# Get the plot data
|
63 |
+
input_features, input_labels, datetime_column = return_plot_data(
|
64 |
+
hours)
|
65 |
+
prediction = return_plot_data_prediction(input_features)
|
66 |
+
return {"features": input_features, "labels": input_labels,
|
67 |
+
"prediction": prediction['predictions'], "datetime": datetime_column}
|
68 |
+
|
69 |
+
|
70 |
+
# f, l, d = return_plot_data()
|
71 |
+
# print(f)
|
72 |
+
# print(l)
|
73 |
+
# print(trainer.predict_with_hopsworks_api(f))
|
74 |
+
|
75 |
+
|
76 |
+
# # Example input data (replace with your actual input structure)
|
77 |
+
# input_ls = [76480.91, 76648.94, 76390.51, 76541.99, 76330.78, 76339.94, 76312.67, 76319.28, 76246.58, 76413.26, 76206.41, 76333.14, 76396.64, 76732.32, 76151.9, 76244.62, 76279.09, 76429.21, 76222.1, 76396.63, 76122.3, 76283.43, 75758.58,
|
78 |
+
# 76272.1, 76349.99, 76366.2, 76093.0, 76117.98, 76395.53, 76456.16, 76319.87, 76348.18, 76461.01, 76481.48, 76300.38, 76395.53, 76330.91, 76517.26, 76323.53, 76461.02, 76532.39, 76583.19, 76319.32, 76330.91, 76509.82, 76570.6, 76415.72, 76534.61]
|
79 |
+
# input_columns = ["open_lag_1", "high_lag_1", "low_lag_1", "close_lag_1", "open_lag_2", "high_lag_2", "low_lag_2", "close_lag_2", "open_lag_3", "high_lag_3", "low_lag_3", "close_lag_3", "open_lag_4", "high_lag_4", "low_lag_4", "close_lag_4", "open_lag_5", "high_lag_5", "low_lag_5", "close_lag_5", "open_lag_6", "high_lag_6", "low_lag_6", "close_lag_6",
|
80 |
+
# "open_lag_7", "high_lag_7", "low_lag_7", "close_lag_7", "open_lag_8", "high_lag_8", "low_lag_8", "close_lag_8", "open_lag_9", "high_lag_9", "low_lag_9", "close_lag_9", "open_lag_10", "high_lag_10", "low_lag_10", "close_lag_10", "open_lag_11", "high_lag_11", "low_lag_11", "close_lag_11", "open_lag_12", "high_lag_12", "low_lag_12", "close_lag_12"]
|
81 |
+
# input_df = pd.DataFrame([input_ls], columns=input_columns)
|
gradio_app.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from fetch_plot_data import get_plot_data
|
5 |
+
|
6 |
+
|
7 |
+
def get_time_series_data():
|
8 |
+
# Fetch and process data
|
9 |
+
plot_data = get_plot_data(hours=24)
|
10 |
+
plot_data["datetime"] = pd.to_datetime(plot_data["datetime"])
|
11 |
+
time_series_data = pd.DataFrame({
|
12 |
+
"Datetime": plot_data["datetime"],
|
13 |
+
"Actual BTC/USD": plot_data["labels"],
|
14 |
+
"Predicted BTC/USD": plot_data["prediction"]
|
15 |
+
})
|
16 |
+
time_series_data = time_series_data.sort_values(by="Datetime")
|
17 |
+
time_series_data["Datetime"] = time_series_data["Datetime"].dt.strftime(
|
18 |
+
"%Y-%m-%d %H:%M")
|
19 |
+
|
20 |
+
all_values = np.concatenate([time_series_data["Actual BTC/USD"],
|
21 |
+
time_series_data["Predicted BTC/USD"]])
|
22 |
+
y_min = np.min(all_values)
|
23 |
+
y_max = np.max(all_values)
|
24 |
+
y_range = y_max - y_min
|
25 |
+
padding = y_range * 0.0005
|
26 |
+
y_min = y_min - padding
|
27 |
+
y_max = y_max + padding
|
28 |
+
|
29 |
+
long_data = time_series_data.melt(
|
30 |
+
id_vars="Datetime",
|
31 |
+
var_name="Series",
|
32 |
+
value_name="BTC/USD Value"
|
33 |
+
)
|
34 |
+
return (long_data, y_min, y_max)
|
35 |
+
|
36 |
+
|
37 |
+
custom_css = """
|
38 |
+
body {
|
39 |
+
background-color: #f8fafc !important;
|
40 |
+
}
|
41 |
+
|
42 |
+
.gradio-container {
|
43 |
+
max-width: 1200px !important;
|
44 |
+
margin: 2rem auto !important;
|
45 |
+
padding: 2rem !important;
|
46 |
+
background-color: white !important;
|
47 |
+
border-radius: 1rem !important;
|
48 |
+
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
|
49 |
+
}
|
50 |
+
|
51 |
+
.main-title {
|
52 |
+
color: #1e293b !important;
|
53 |
+
font-size: 2.5rem !important;
|
54 |
+
font-weight: 700 !important;
|
55 |
+
text-align: center !important;
|
56 |
+
margin-bottom: 0.5rem !important;
|
57 |
+
line-height: 1.2 !important;
|
58 |
+
}
|
59 |
+
|
60 |
+
.subtitle {
|
61 |
+
color: #64748b !important;
|
62 |
+
font-size: 1.125rem !important;
|
63 |
+
text-align: center !important;
|
64 |
+
margin-bottom: 1.5rem !important;
|
65 |
+
font-weight: 500 !important;
|
66 |
+
}
|
67 |
+
|
68 |
+
.chart-container {
|
69 |
+
margin-bottom: 1rem !important;
|
70 |
+
}
|
71 |
+
|
72 |
+
.footer-content {
|
73 |
+
margin-top: 1rem !important;
|
74 |
+
padding-top: 1rem !important;
|
75 |
+
border-top: 1px solid #e2e8f0 !important;
|
76 |
+
display: flex !important;
|
77 |
+
justify-content: space-between !important;
|
78 |
+
align-items: center !important;
|
79 |
+
color: #64748b !important;
|
80 |
+
font-size: 0.875rem !important;
|
81 |
+
}
|
82 |
+
|
83 |
+
.footer-left {
|
84 |
+
text-align: left !important;
|
85 |
+
}
|
86 |
+
|
87 |
+
.footer-right {
|
88 |
+
text-align: right !important;
|
89 |
+
}
|
90 |
+
|
91 |
+
.developer-info {
|
92 |
+
color: #3b82f6 !important;
|
93 |
+
font-weight: 500 !important;
|
94 |
+
text-decoration: none !important;
|
95 |
+
transition: color 0.2s !important;
|
96 |
+
}
|
97 |
+
|
98 |
+
.developer-info:hover {
|
99 |
+
color: #2563eb !important;
|
100 |
+
}
|
101 |
+
"""
|
102 |
+
|
103 |
+
# Initialize the Gradio app
|
104 |
+
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as app:
|
105 |
+
with gr.Column():
|
106 |
+
# Title and subtitle
|
107 |
+
gr.Markdown("""
|
108 |
+
<div class="main-title">Live BTC/USD Time Series Info</div>
|
109 |
+
<div class="subtitle">Predictions served via Hopsworks API</div>
|
110 |
+
""")
|
111 |
+
|
112 |
+
initial_data, initial_y_min, initial_y_max = get_time_series_data()
|
113 |
+
|
114 |
+
# Chart with reduced bottom margin
|
115 |
+
with gr.Column(elem_classes=["chart-container"]):
|
116 |
+
line_plot = gr.LinePlot(
|
117 |
+
value=initial_data,
|
118 |
+
x="Datetime",
|
119 |
+
y="BTC/USD Value",
|
120 |
+
color="Series",
|
121 |
+
title="",
|
122 |
+
y_title="BTC/USD Value",
|
123 |
+
x_title="Time",
|
124 |
+
x_label_angle=45,
|
125 |
+
width=1000,
|
126 |
+
height=450, # Slightly reduced height
|
127 |
+
colors={
|
128 |
+
"Actual BTC/USD": "#3b82f6",
|
129 |
+
"Predicted BTC/USD": "#ef4444"
|
130 |
+
},
|
131 |
+
tooltip=["Datetime", "BTC/USD Value", "Series"],
|
132 |
+
overlay_point=True,
|
133 |
+
zoom=False,
|
134 |
+
pan=False,
|
135 |
+
show_label=True,
|
136 |
+
stroke_width=2,
|
137 |
+
y_min=initial_y_min,
|
138 |
+
y_max=initial_y_max,
|
139 |
+
y_lim=[initial_y_min, initial_y_max],
|
140 |
+
show_grid=True,
|
141 |
+
)
|
142 |
+
|
143 |
+
# Footer with timestamp and developer info
|
144 |
+
gr.Markdown(f"""
|
145 |
+
<div class="footer-content">
|
146 |
+
<div class="footer-left">
|
147 |
+
Last updated: {pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")}
|
148 |
+
<br>
|
149 |
+
<a href="https://nafis-neehal.github.io/" target="_blank" class="developer-info">Developed by Nafis Neehal</a>
|
150 |
+
</div>
|
151 |
+
</div>
|
152 |
+
""")
|
153 |
+
|
154 |
+
# Launch the app
|
155 |
+
app.launch()
|