eaglelandsonce committed
Commit f93ed07 · verified · 1 Parent(s): 98a7aa8

Update app.py

Files changed (1)
  1. app.py +173 -24
app.py CHANGED

@@ -1,7 +1,24 @@
+import streamlit as st
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import matplotlib.pyplot as plt
+from sklearn.preprocessing import StandardScaler, LabelEncoder
+import numpy as np
+
+# Global scaler and label encoder for consistent preprocessing
+scaler = StandardScaler()
+label_encoder = LabelEncoder()
+feature_columns = None  # To store feature columns from the training data
+model = None  # Declare the model globally for predictions
+
+# Preload default files
+DEFAULT_TRAIN_FILE = "patientdata.csv"
 DEFAULT_PREDICT_FILE = "synthetic_breast_cancer_data_withColumn.csv"
 
 def main():
-    global feature_columns
+    global feature_columns, model
 
     st.title("Patient Treatment Prediction App")
     st.write("Upload patient data to train a model and predict treatments based on input data.")
@@ -10,10 +27,20 @@ def main():
     uploaded_file = st.file_uploader("Upload a CSV file for training", type="csv")
     if uploaded_file is None:
         st.write("Using default training data.")
-        data = pd.read_csv(DEFAULT_TRAIN_FILE)
+        try:
+            data = pd.read_csv(DEFAULT_TRAIN_FILE)
+        except Exception as e:
+            st.error(f"Error loading default training file: {e}")
+            return
     else:
-        data = pd.read_csv(uploaded_file)
-    st.write("Training Dataset Preview:", data.head())
+        try:
+            data = pd.read_csv(uploaded_file)
+        except Exception as e:
+            st.error(f"Error loading uploaded file: {e}")
+            return
+
+    st.write("Training Dataset Preview:")
+    st.dataframe(data.head())  # Use st.dataframe for better visibility
 
     # Check for Treatment column in training data
     if 'Treatment' not in data.columns:
@@ -21,7 +48,11 @@ def main():
         return
 
     # Prepare Data
-    X, y, input_dim, num_classes, feature_columns = preprocess_training_data(data)
+    try:
+        X, y, input_dim, num_classes, feature_columns = preprocess_training_data(data)
+    except Exception as e:
+        st.error(f"Error during data preprocessing: {e}")
+        return
 
     # Model Parameters
     hidden_dim = st.slider("Hidden Layer Dimension", 10, 100, 50)
@@ -30,37 +61,155 @@ def main():
 
     # Model training
     if st.button("Train Model"):
-        model, loss_curve = train_model(X, y, input_dim, hidden_dim, num_classes, learning_rate, epochs)
-        plot_loss_curve(loss_curve)
+        try:
+            model, loss_curve = train_model(X, y, input_dim, hidden_dim, num_classes, learning_rate, epochs)
+            plot_loss_curve(loss_curve)
+            st.success("Model trained successfully!")
+        except Exception as e:
+            st.error(f"Error during model training: {e}")
+            return
 
     # Upload data for prediction
     st.write("Upload new data for prediction (ensure 'Treatment' column is removed if present).")
     new_data_file = st.file_uploader("Upload new CSV file for prediction", type="csv")
     if new_data_file is None:
         st.write("Using default prediction data.")
-        new_data = pd.read_csv(DEFAULT_PREDICT_FILE)
+        try:
+            new_data = pd.read_csv(DEFAULT_PREDICT_FILE)
+        except Exception as e:
+            st.error(f"Error loading default prediction file: {e}")
+            return
     else:
-        new_data = pd.read_csv(new_data_file)
-
+        try:
+            new_data = pd.read_csv(new_data_file)
+        except Exception as e:
+            st.error(f"Error loading uploaded prediction file: {e}")
+            return
+
     # Drop 'Treatment' column if it exists
     if 'Treatment' in new_data.columns:
         st.warning("The 'Treatment' column is present in the prediction data and will be removed.")
         new_data = new_data.drop(columns=['Treatment'])
-
-    st.write("Prediction Dataset Preview:", new_data.head())
 
-    if 'model' in locals() and feature_columns is not None:
-        # Align columns to match training data
-        new_data_aligned = align_columns(new_data, feature_columns)
-
-        if new_data_aligned is not None:
-            predictions = predict_treatment(new_data_aligned, model)
+    st.write("Prediction Dataset Preview:")
+    st.dataframe(new_data.head())  # Display new data
+
+    if model is not None and feature_columns is not None:
+        try:
+            # Align columns to match training data
+            new_data_aligned = align_columns(new_data, feature_columns)
 
-            # Display Predictions in an Output Box
-            st.subheader("Predicted Treatment Outcomes")
-            prediction_output = "\n".join([f"Patient {i+1}: {pred}" for i, pred in enumerate(predictions)])
-            st.text_area("Prediction Results", prediction_output, height=200)
-        else:
-            st.error("Unable to align prediction data to the training feature columns.")
+            if new_data_aligned is not None:
+                predictions = predict_treatment(new_data_aligned, model)
+
+                # Display Predictions in an Output Box
+                st.subheader("Predicted Treatment Outcomes")
+                prediction_output = "\n".join([f"Patient {i+1}: {pred}" for i, pred in enumerate(predictions)])
+                st.text_area("Prediction Results", prediction_output, height=200)
+            else:
+                st.error("Unable to align prediction data to the training feature columns.")
+        except Exception as e:
+            st.error(f"Error during prediction: {e}")
     else:
         st.warning("Please train the model first before predicting on new data.")
+
+def preprocess_training_data(data):
+    global scaler, label_encoder
+
+    # Label encode the 'Treatment' target column
+    data['Treatment'] = label_encoder.fit_transform(data['Treatment'])
+    y = data['Treatment'].values
+
+    # Encode and standardize feature columns
+    X = data.drop('Treatment', axis=1)
+    feature_columns = X.columns  # Store feature columns for later alignment
+    for col in X.select_dtypes(include=['object']).columns:
+        X[col] = LabelEncoder().fit_transform(X[col])
+
+    # Standardize features
+    X = scaler.fit_transform(X)
+
+    return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long), X.shape[1], len(np.unique(y)), feature_columns
+
+def align_columns(new_data, feature_columns):
+    try:
+        # Ensure the new data has the same columns as the training data
+        missing_cols = set(feature_columns) - set(new_data.columns)
+        extra_cols = set(new_data.columns) - set(feature_columns)
+
+        # Remove any extra columns
+        new_data = new_data.drop(columns=extra_cols)
+
+        # Add missing columns with default value 0
+        for col in missing_cols:
+            new_data[col] = 0
+
+        # Reorder columns to match the training data
+        new_data = new_data[feature_columns]
+
+        # Encode and standardize feature columns
+        for col in new_data.select_dtypes(include=['object']).columns:
+            new_data[col] = LabelEncoder().fit_transform(new_data[col])
+
+        # Scale features
+        new_data = scaler.transform(new_data)
+
+        return torch.tensor(new_data, dtype=torch.float32)
+    except Exception as e:
+        st.error(f"Error aligning columns: {e}")
+        return None
+
+def train_model(X, y, input_dim, hidden_dim, num_classes, learning_rate, epochs):
+    class SimpleNN(nn.Module):
+        def __init__(self, input_dim, hidden_dim, num_classes):
+            super(SimpleNN, self).__init__()
+            self.fc1 = nn.Linear(input_dim, hidden_dim)
+            self.relu = nn.ReLU()
+            self.fc2 = nn.Linear(hidden_dim, num_classes)
+
+        def forward(self, x):
+            x = self.fc1(x)
+            x = self.relu(x)
+            x = self.fc2(x)
+            return x
+
+    model = SimpleNN(input_dim, hidden_dim, num_classes)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+    loss_curve = []
+    for epoch in range(epochs):
+        optimizer.zero_grad()
+        outputs = model(X)
+        loss = criterion(outputs, y)
+        loss.backward()
+        optimizer.step()
+        loss_curve.append(loss.item())
+
+    return model, loss_curve
+
+def plot_loss_curve(loss_curve):
+    plt.figure()
+    plt.plot(loss_curve, label="Training Loss")
+    plt.xlabel("Epochs")
+    plt.ylabel("Loss")
+    plt.title("Loss Curve")
+    plt.legend()
+    plt.tight_layout()  # Ensure layout is tight for Streamlit
+    st.pyplot(plt)
+
+def predict_treatment(new_data, model, batch_size=32):
+    model.eval()
+    predictions = []
+
+    with torch.no_grad():
+        for i in range(0, new_data.size(0), batch_size):
+            batch_data = new_data[i:i + batch_size]
+            outputs = model(batch_data)
+            _, batch_predictions = torch.max(outputs, 1)
+            predictions.extend(batch_predictions.numpy())
+
+    return label_encoder.inverse_transform(predictions)
+
+if __name__ == "__main__":
+    main()
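
Note on Streamlit state: the script reruns top to bottom on every widget interaction, so the module-level `model` global assigned inside the "Train Model" branch does not survive the rerun triggered by a later upload, and the `model is not None` guard can still fall through to the "train first" warning. A common workaround, not part of this commit, is to keep the trained model in `st.session_state`; a minimal sketch under that assumption, reusing the names from main() above:

    # Sketch only: persist the trained model across Streamlit reruns.
    # Assumes train_model(...) and the X, y, hyperparameter variables from main().
    if st.button("Train Model"):
        trained, loss_curve = train_model(X, y, input_dim, hidden_dim,
                                          num_classes, learning_rate, epochs)
        st.session_state["model"] = trained  # survives subsequent reruns

    model = st.session_state.get("model")  # None until a model has been trained
    if model is not None:
        ...  # safe to call predict_treatment(new_data_aligned, model)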
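
A related caveat in `align_columns`: it fits a fresh `LabelEncoder` on each object column of the prediction data, so a given category can map to a different integer than it did at training time before `scaler.transform` is applied. A minimal sketch of the usual remedy, fitting one encoder per column during training and reusing it at prediction time; the `encoders` dict and both helpers are hypothetical, not part of this commit:

    from sklearn.preprocessing import LabelEncoder

    encoders = {}  # hypothetical: one fitted encoder per categorical column

    def encode_train(X):
        # Training time: fit and remember an encoder for each object column.
        for col in X.select_dtypes(include=["object"]).columns:
            encoders[col] = LabelEncoder().fit(X[col])
            X[col] = encoders[col].transform(X[col])
        return X

    def encode_predict(new_data):
        # Prediction time: reuse the stored encoders so integer codes match.
        for col, enc in encoders.items():
            new_data[col] = enc.transform(new_data[col])  # raises on unseen labels
        return new_data

For the column alignment itself, `new_data.reindex(columns=feature_columns, fill_value=0)` collapses the drop/add/reorder steps into a single call.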