Spaces:
Running
Running
Sadjad Alikhani
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -12,16 +12,8 @@ from sklearn.metrics import confusion_matrix
|
|
12 |
import matplotlib.pyplot as plt
|
13 |
import pandas as pd
|
14 |
|
15 |
-
# Paths to the predefined images folder
|
16 |
-
RAW_PATH = os.path.join("images", "raw")
|
17 |
-
EMBEDDINGS_PATH = os.path.join("images", "embeddings")
|
18 |
-
|
19 |
-
# Specific values for percentage of data for training
|
20 |
-
percentage_values = (np.arange(9) + 1)*10
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
|
|
|
25 |
def beam_prediction_task(data_percentage, task_complexity):
|
26 |
# Folder naming convention based on input_type, data_percentage, and task_complexity
|
27 |
raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
|
@@ -92,40 +84,6 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
|
92 |
plt.savefig(save_path)
|
93 |
plt.close()
|
94 |
|
95 |
-
|
96 |
-
#def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
|
97 |
-
# plt.figure(figsize=(8, 6))
|
98 |
-
# plt.imshow(cm, interpolation='nearest', cmap='coolwarm')
|
99 |
-
# plt.title(title)
|
100 |
-
# plt.colorbar()
|
101 |
-
# tick_marks = np.arange(len(classes))
|
102 |
-
# plt.xticks(tick_marks, classes, rotation=45)
|
103 |
-
# plt.yticks(tick_marks, classes)
|
104 |
-
#
|
105 |
-
# plt.tight_layout()
|
106 |
-
# plt.ylabel('True label')
|
107 |
-
# plt.xlabel('Predicted label')
|
108 |
-
# plt.savefig(save_path)
|
109 |
-
# plt.close()
|
110 |
-
|
111 |
-
# Function to compute the average confusion matrix across CSV files in a folder
|
112 |
-
#def compute_average_confusion_matrix(folder):
|
113 |
-
# confusion_matrices = []
|
114 |
-
# for file in os.listdir(folder):
|
115 |
-
# if file.endswith(".csv"):
|
116 |
-
# data = pd.read_csv(os.path.join(folder, file))
|
117 |
-
# y_true = data["Target"]
|
118 |
-
# y_pred = data["Top-1 Prediction"]
|
119 |
-
# num_labels = len(np.unique(y_true))
|
120 |
-
# cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_labels))
|
121 |
-
# confusion_matrices.append(cm)
|
122 |
-
#
|
123 |
-
# if confusion_matrices:
|
124 |
-
# avg_cm = np.mean(confusion_matrices, axis=0)
|
125 |
-
# return avg_cm
|
126 |
-
# else:
|
127 |
-
# return None
|
128 |
-
|
129 |
def compute_average_confusion_matrix(folder):
|
130 |
confusion_matrices = []
|
131 |
max_num_labels = 0
|
@@ -162,10 +120,99 @@ def compute_average_confusion_matrix(folder):
|
|
162 |
else:
|
163 |
return None
|
164 |
|
|
|
165 |
|
166 |
|
|
|
|
|
167 |
|
|
|
|
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
# Custom class to capture print output
|
171 |
class PrintCapture(io.StringIO):
|
@@ -410,7 +457,7 @@ def process_hdf5_file(uploaded_file, percentage_idx):
|
|
410 |
os.chdir(original_dir)
|
411 |
sys.stdout = sys.__stdout__ # Reset print statements
|
412 |
|
413 |
-
|
414 |
with gr.Blocks(css="""
|
415 |
.slider-container {
|
416 |
display: inline-block;
|
@@ -439,17 +486,35 @@ with gr.Blocks(css="""
|
|
439 |
# Separate Tab for LoS/NLoS Classification Task
|
440 |
with gr.Tab("LoS/NLoS Classification Task"):
|
441 |
gr.Markdown("### LoS/NLoS Classification Task")
|
442 |
-
file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])
|
443 |
|
444 |
-
|
445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
446 |
with gr.Row():
|
447 |
raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
|
448 |
embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
|
449 |
output_textbox = gr.Textbox(label="Console Output", lines=10)
|
450 |
|
451 |
-
#
|
452 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
453 |
|
454 |
# Launch the app
|
455 |
if __name__ == "__main__":
|
|
|
12 |
import matplotlib.pyplot as plt
|
13 |
import pandas as pd
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
+
#################### BEAM PREDICTION #########################}
|
17 |
def beam_prediction_task(data_percentage, task_complexity):
|
18 |
# Folder naming convention based on input_type, data_percentage, and task_complexity
|
19 |
raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
|
|
|
84 |
plt.savefig(save_path)
|
85 |
plt.close()
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
def compute_average_confusion_matrix(folder):
|
88 |
confusion_matrices = []
|
89 |
max_num_labels = 0
|
|
|
120 |
else:
|
121 |
return None
|
122 |
|
123 |
+
########################## LOS/NLOS CLASSIFICATION #############################3
|
124 |
|
125 |
|
126 |
+
# Paths to the predefined images folder
|
127 |
+
LOS_PATH = "images_LoS"
|
128 |
|
129 |
+
# Define the percentage values
|
130 |
+
percentage_values_los = np.linspace(0.1, 1, 20) * 100 # 20 percentage values
|
131 |
|
132 |
+
# Function to compute confusion matrix and plot it
|
133 |
+
def plot_confusion_matrix_from_csv(csv_file_path, title, save_path):
|
134 |
+
# Load CSV file
|
135 |
+
data = pd.read_csv(csv_file_path)
|
136 |
+
|
137 |
+
# Extract ground truth and predictions
|
138 |
+
y_true = data['ground-truth']
|
139 |
+
y_pred = data['predicted']
|
140 |
+
|
141 |
+
# Compute confusion matrix
|
142 |
+
cm = confusion_matrix(y_true, y_pred)
|
143 |
+
|
144 |
+
# Plot the confusion matrix
|
145 |
+
plt.figure(figsize=(5, 5))
|
146 |
+
plt.imshow(cm, interpolation='nearest', cmap='Blues')
|
147 |
+
plt.title(title)
|
148 |
+
plt.colorbar()
|
149 |
+
plt.xticks([0, 1], labels=['Class 0', 'Class 1'])
|
150 |
+
plt.yticks([0, 1], labels=['Class 0', 'Class 1'])
|
151 |
+
|
152 |
+
# Annotate the confusion matrix
|
153 |
+
thresh = cm.max() / 2
|
154 |
+
for i in range(cm.shape[0]):
|
155 |
+
for j in range(cm.shape[1]):
|
156 |
+
plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",
|
157 |
+
color="white" if cm[i, j] > thresh else "black")
|
158 |
+
|
159 |
+
plt.ylabel('True label')
|
160 |
+
plt.xlabel('Predicted label')
|
161 |
+
plt.tight_layout()
|
162 |
+
|
163 |
+
# Save the plot as an image
|
164 |
+
plt.savefig(save_path)
|
165 |
+
plt.close()
|
166 |
+
|
167 |
+
# Return the saved image
|
168 |
+
return Image.open(save_path)
|
169 |
+
|
170 |
+
# Function to load confusion matrix based on percentage and input_type
|
171 |
+
def display_confusion_matrices_los(percentage_idx):
|
172 |
+
percentage = percentage_values_los[percentage_idx]
|
173 |
+
|
174 |
+
# Construct folder names
|
175 |
+
raw_folder = os.path.join(LOS_PATH, f"raw_{percentage/100:.3f}_los_noTraining")
|
176 |
+
embeddings_folder = os.path.join(LOS_PATH, f"embedding_{percentage/100:.3f}_los_noTraining")
|
177 |
+
|
178 |
+
# Process raw confusion matrix
|
179 |
+
raw_csv_file = os.path.join(raw_folder, "confusion_matrix.csv")
|
180 |
+
raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
|
181 |
+
raw_img = plot_confusion_matrix_from_csv(raw_csv_file,
|
182 |
+
f"Raw Confusion Matrix ({percentage:.1f}% data)",
|
183 |
+
raw_cm_img_path)
|
184 |
+
|
185 |
+
# Process embeddings confusion matrix
|
186 |
+
embeddings_csv_file = os.path.join(embeddings_folder, "confusion_matrix.csv")
|
187 |
+
embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
|
188 |
+
embeddings_img = plot_confusion_matrix_from_csv(embeddings_csv_file,
|
189 |
+
f"Embeddings Confusion Matrix ({percentage:.1f}% data)",
|
190 |
+
embeddings_cm_img_path)
|
191 |
+
|
192 |
+
return raw_img, embeddings_img
|
193 |
+
|
194 |
+
# Main function to handle user choice
|
195 |
+
def handle_user_choice(choice, percentage_idx=None, uploaded_file=None):
|
196 |
+
if choice == "Use Predefined Data":
|
197 |
+
return display_confusion_matrices_los(percentage_idx)
|
198 |
+
elif choice == "Upload Dataset":
|
199 |
+
if uploaded_file is not None:
|
200 |
+
return process_hdf5_file(uploaded_file, percentage_idx)
|
201 |
+
else:
|
202 |
+
return "Please upload a dataset", "Please upload a dataset"
|
203 |
+
else:
|
204 |
+
return "Invalid choice", "Invalid choice"
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
|
217 |
# Custom class to capture print output
|
218 |
class PrintCapture(io.StringIO):
|
|
|
457 |
os.chdir(original_dir)
|
458 |
sys.stdout = sys.__stdout__ # Reset print statements
|
459 |
|
460 |
+
######################## Define the Gradio interface ###############################
|
461 |
with gr.Blocks(css="""
|
462 |
.slider-container {
|
463 |
display: inline-block;
|
|
|
486 |
# Separate Tab for LoS/NLoS Classification Task
|
487 |
with gr.Tab("LoS/NLoS Classification Task"):
|
488 |
gr.Markdown("### LoS/NLoS Classification Task")
|
|
|
489 |
|
490 |
+
# Radio button for user choice: predefined data or upload dataset
|
491 |
+
choice_radio = gr.Radio(choices=["Use Predefined Data", "Upload Dataset"], label="Choose how to proceed", value="Use Predefined Data")
|
492 |
+
|
493 |
+
# Dropdown for selecting percentage for predefined data
|
494 |
+
percentage_dropdown_los = gr.Dropdown(choices=list(range(20)), value=0, label="Percentage of Data for Training")
|
495 |
+
|
496 |
+
# File uploader for dataset (only visible if user chooses to upload a dataset)
|
497 |
+
file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"], visible=False)
|
498 |
+
|
499 |
+
# Confusion matrices display
|
500 |
with gr.Row():
|
501 |
raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
|
502 |
embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
|
503 |
output_textbox = gr.Textbox(label="Console Output", lines=10)
|
504 |
|
505 |
+
# Update the file uploader visibility based on user choice
|
506 |
+
def toggle_file_input(choice):
|
507 |
+
return gr.update(visible=(choice == "Upload Dataset"))
|
508 |
+
|
509 |
+
choice_radio.change(fn=toggle_file_input, inputs=[choice_radio], outputs=file_input)
|
510 |
+
|
511 |
+
# When user makes a choice, update the display
|
512 |
+
choice_radio.change(fn=handle_user_choice, inputs=[choice_radio, percentage_dropdown_los, file_input],
|
513 |
+
outputs=[raw_img_los, embeddings_img_los, output_textbox])
|
514 |
+
|
515 |
+
# When percentage slider changes (for predefined data)
|
516 |
+
percentage_dropdown_los.change(fn=handle_user_choice, inputs=[choice_radio, percentage_dropdown_los, file_input],
|
517 |
+
outputs=[raw_img_los, embeddings_img_los, output_textbox])
|
518 |
|
519 |
# Launch the app
|
520 |
if __name__ == "__main__":
|